use std::time::Duration;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::llm::ChatMessage;
/// A single tool invocation recorded during a job, together with its outcome.
///
/// Build one with [`ActionRecord::new`], finish it with [`ActionRecord::succeed`]
/// or [`ActionRecord::fail`], and attach optional metadata with
/// [`ActionRecord::with_warnings`] / [`ActionRecord::with_cost`]. Every one of
/// these consumes the record and returns the updated value, so calls chain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionRecord {
    /// Identifier generated randomly (UUID v4) when the record is created.
    pub id: Uuid,
    /// Caller-supplied ordering index of this action.
    pub sequence: u32,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Input passed to the tool.
    pub input: serde_json::Value,
    /// Unprocessed tool output, as handed to [`ActionRecord::succeed`].
    pub output_raw: Option<String>,
    /// Sanitized tool output; set by [`ActionRecord::succeed`].
    pub output_sanitized: Option<serde_json::Value>,
    /// Warnings attached via [`ActionRecord::with_warnings`]
    /// (presumably produced while sanitizing the output — confirm with callers).
    pub sanitization_warnings: Vec<String>,
    /// Cost attributed to this action, if known. Units are not defined here.
    pub cost: Option<Decimal>,
    /// Duration reported to `succeed`/`fail`; `Duration::ZERO` until then.
    pub duration: Duration,
    /// `true` after [`ActionRecord::succeed`]; `false` for new or failed records.
    pub success: bool,
    /// Error message set by [`ActionRecord::fail`]; `None` otherwise.
    pub error: Option<String>,
    /// Timestamp taken when the record was created (not when it finished).
    pub executed_at: DateTime<Utc>,
}

impl ActionRecord {
    /// Creates a pending record for `tool_name` called with `input`.
    ///
    /// The record starts unsuccessful, with no output, no error, no cost, no
    /// warnings, and a zero duration. `executed_at` is stamped with the current
    /// time here, at creation.
    #[must_use]
    pub fn new(sequence: u32, tool_name: impl Into<String>, input: serde_json::Value) -> Self {
        Self {
            id: Uuid::new_v4(),
            sequence,
            tool_name: tool_name.into(),
            input,
            output_raw: None,
            output_sanitized: None,
            sanitization_warnings: Vec::new(),
            cost: None,
            duration: Duration::ZERO,
            success: false,
            error: None,
            executed_at: Utc::now(),
        }
    }

    /// Marks the record as successful and stores the tool output and duration.
    ///
    /// Any error left behind by an earlier [`ActionRecord::fail`] call is
    /// cleared, so a record never reports success and an error at the same time.
    #[must_use = "`succeed` returns the updated record; the original is consumed"]
    pub fn succeed(
        mut self,
        output_raw: Option<String>,
        output_sanitized: serde_json::Value,
        duration: Duration,
    ) -> Self {
        self.success = true;
        // Without this, `fail(..).succeed(..)` would yield success == true
        // alongside a stale error message.
        self.error = None;
        self.output_raw = output_raw;
        self.output_sanitized = Some(output_sanitized);
        self.duration = duration;
        self
    }

    /// Marks the record as failed with `error` and the given duration.
    ///
    /// Output fields are left untouched.
    #[must_use = "`fail` returns the updated record; the original is consumed"]
    pub fn fail(mut self, error: impl Into<String>, duration: Duration) -> Self {
        self.success = false;
        self.error = Some(error.into());
        self.duration = duration;
        self
    }

    /// Replaces the warnings attached to this record with `warnings`.
    #[must_use = "`with_warnings` returns the updated record; the original is consumed"]
    pub fn with_warnings(mut self, warnings: Vec<String>) -> Self {
        self.sanitization_warnings = warnings;
        self
    }

    /// Sets the cost attributed to this action, replacing any previous value.
    #[must_use = "`with_cost` returns the updated record; the original is consumed"]
    pub fn with_cost(mut self, cost: Decimal) -> Self {
        self.cost = Some(cost);
        self
    }
}
/// A bounded, in-order history of chat messages.
///
/// When the history grows past `max_messages`, the oldest messages are evicted
/// first. A system-role message at the front of the history is pinned: eviction
/// skips over it and removes the message right after it instead. Consequently,
/// with `max_messages == 0` a lone leading system message is still retained.
#[derive(Debug, Clone)]
pub struct ConversationMemory {
    messages: Vec<ChatMessage>,
    max_messages: usize,
}

impl Default for ConversationMemory {
    /// Creates an empty history capped at
    /// [`ConversationMemory::DEFAULT_MAX_MESSAGES`] messages.
    ///
    /// A derived `Default` would set the cap to 0, which makes `add` discard
    /// every message except a leading system prompt.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_MESSAGES)
    }
}

impl ConversationMemory {
    /// Message cap used by the `Default` implementation.
    pub const DEFAULT_MAX_MESSAGES: usize = 100;

    /// Creates an empty history that keeps at most `max_messages` messages
    /// (see the type-level docs for the leading-system-message exception).
    pub fn new(max_messages: usize) -> Self {
        Self {
            messages: Vec::new(),
            max_messages,
        }
    }

    /// Appends `message`, then evicts old messages until the history fits.
    ///
    /// Eviction removes from the front. If the front message has the system
    /// role, it is kept and the message after it is removed instead. If only
    /// that system message remains, eviction stops even if the cap is exceeded.
    pub fn add(&mut self, message: ChatMessage) {
        self.messages.push(message);
        while self.messages.len() > self.max_messages {
            let head_is_system =
                self.messages.first().map(|m| m.role) == Some(crate::llm::Role::System);
            if !head_is_system {
                self.messages.remove(0);
            } else if self.messages.len() > 1 {
                // Keep the pinned system prompt; drop the oldest message after it.
                self.messages.remove(1);
            } else {
                // Only the system prompt is left; it is never evicted.
                break;
            }
        }
    }

    /// Returns the full retained history, oldest first.
    pub fn messages(&self) -> &[ChatMessage] {
        &self.messages
    }

    /// Returns up to the `n` most recent messages, oldest first.
    ///
    /// If fewer than `n` messages are stored, all of them are returned.
    pub fn last_n(&self, n: usize) -> &[ChatMessage] {
        let start = self.messages.len().saturating_sub(n);
        &self.messages[start..]
    }

    /// Removes every message, including any pinned system message.
    pub fn clear(&mut self) {
        self.messages.clear();
    }

    /// Returns the number of retained messages.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// Returns `true` when no messages are retained.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
}
/// Working memory for a single job: its conversation history plus every
/// tool action recorded so far.
#[derive(Debug, Clone)]
pub struct Memory {
    /// Job this memory belongs to.
    pub job_id: Uuid,
    /// Conversation history; [`Memory::new`] creates it with a 100-message cap.
    pub conversation: ConversationMemory,
    /// Recorded actions, in the order they were passed to [`Memory::record_action`].
    pub actions: Vec<ActionRecord>,
    /// Sequence number that the next [`Memory::create_action`] call hands out.
    next_sequence: u32,
}

impl Memory {
    /// Creates empty memory for `job_id`, with a conversation cap of 100 messages.
    pub fn new(job_id: Uuid) -> Self {
        Self {
            job_id,
            conversation: ConversationMemory::new(100),
            actions: Vec::new(),
            next_sequence: 0,
        }
    }

    /// Appends `message` to the conversation history (subject to its cap).
    pub fn add_message(&mut self, message: ChatMessage) {
        self.conversation.add(message);
    }

    /// Creates a new pending [`ActionRecord`] with the next sequence number.
    ///
    /// The record is *not* stored; pass it to [`Memory::record_action`] once it
    /// is finished. Sequence numbers are consumed even if the record is dropped.
    ///
    /// # Panics
    ///
    /// Panics if more than `u32::MAX` actions have been created for this memory.
    #[must_use = "the action is not stored until it is passed to `record_action`"]
    pub fn create_action(
        &mut self,
        tool_name: impl Into<String>,
        input: serde_json::Value,
    ) -> ActionRecord {
        let seq = self.next_sequence;
        // A plain `+= 1` wraps to 0 in release builds and would silently reuse
        // sequence numbers; fail loudly instead.
        self.next_sequence = seq
            .checked_add(1)
            .expect("action sequence number overflowed u32");
        ActionRecord::new(seq, tool_name, input)
    }

    /// Stores a (typically finished) action at the end of the action log.
    pub fn record_action(&mut self, action: ActionRecord) {
        self.actions.push(action);
    }

    /// Sums the cost of every recorded action that has one; actions without a
    /// cost contribute nothing. Returns zero when there are no costed actions.
    pub fn total_cost(&self) -> Decimal {
        self.actions.iter().filter_map(|a| a.cost).sum()
    }

    /// Sums the recorded durations of all actions.
    pub fn total_duration(&self) -> Duration {
        self.actions.iter().map(|a| a.duration).sum()
    }

    /// Counts recorded actions whose `success` flag is set.
    pub fn successful_actions(&self) -> usize {
        self.actions.iter().filter(|a| a.success).count()
    }

    /// Counts recorded actions whose `success` flag is not set.
    pub fn failed_actions(&self) -> usize {
        self.actions.iter().filter(|a| !a.success).count()
    }

    /// Returns the most recently recorded action, if any.
    pub fn last_action(&self) -> Option<&ActionRecord> {
        self.actions.last()
    }

    /// Returns every recorded action whose tool name equals `tool_name`,
    /// in recording order.
    pub fn actions_by_tool(&self, tool_name: &str) -> Vec<&ActionRecord> {
        self.actions
            .iter()
            .filter(|a| a.tool_name == tool_name)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_action_record() {
        let action = ActionRecord::new(0, "test", serde_json::json!({"key": "value"}));
        assert_eq!(action.sequence, 0);
        assert!(!action.success);
        let action = action.succeed(
            Some("raw".to_string()),
            serde_json::json!({"result": "ok"}),
            Duration::from_millis(100),
        );
        assert!(action.success);
    }

    #[test]
    fn test_action_record_failure() {
        let action = ActionRecord::new(3, "tool", serde_json::json!(null))
            .fail("boom", Duration::from_millis(5));
        assert!(!action.success);
        assert_eq!(action.error.as_deref(), Some("boom"));
        assert_eq!(action.duration, Duration::from_millis(5));
        assert!(action.output_sanitized.is_none());
    }

    #[test]
    fn test_conversation_memory() {
        let mut memory = ConversationMemory::new(3);
        memory.add(ChatMessage::user("Hello"));
        memory.add(ChatMessage::assistant("Hi"));
        memory.add(ChatMessage::user("How are you?"));
        memory.add(ChatMessage::assistant("Good!"));
        assert_eq!(memory.len(), 3);
    }

    #[test]
    fn test_conversation_memory_window_and_clear() {
        let mut memory = ConversationMemory::new(2);
        assert!(memory.is_empty());
        memory.add(ChatMessage::user("a"));
        memory.add(ChatMessage::assistant("b"));
        memory.add(ChatMessage::user("c"));
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.last_n(1).len(), 1);
        assert_eq!(memory.last_n(10).len(), 2);
        assert_eq!(memory.messages().len(), 2);
        memory.clear();
        assert!(memory.is_empty());
        assert!(memory.last_n(5).is_empty());
    }

    #[test]
    fn test_memory_totals() {
        let mut memory = Memory::new(Uuid::new_v4());
        let action1 = memory
            .create_action("tool1", serde_json::json!({}))
            .succeed(None, serde_json::json!({}), Duration::from_secs(1))
            .with_cost(Decimal::new(10, 1));
        memory.record_action(action1);
        let action2 = memory
            .create_action("tool2", serde_json::json!({}))
            .succeed(None, serde_json::json!({}), Duration::from_secs(2))
            .with_cost(Decimal::new(20, 1));
        memory.record_action(action2);
        assert_eq!(memory.total_cost(), Decimal::new(30, 1));
        assert_eq!(memory.total_duration(), Duration::from_secs(3));
        assert_eq!(memory.successful_actions(), 2);
    }

    #[test]
    fn test_memory_sequences_and_queries() {
        let mut memory = Memory::new(Uuid::new_v4());
        assert!(memory.last_action().is_none());
        assert_eq!(memory.total_cost(), Decimal::ZERO);

        let first = memory.create_action("search", serde_json::json!({}));
        let second = memory.create_action("fetch", serde_json::json!({}));
        assert_eq!(first.sequence, 0);
        assert_eq!(second.sequence, 1);

        memory.record_action(first.succeed(None, serde_json::json!({}), Duration::ZERO));
        memory.record_action(second.fail("timeout", Duration::from_secs(1)));

        assert_eq!(memory.successful_actions(), 1);
        assert_eq!(memory.failed_actions(), 1);
        assert_eq!(memory.actions_by_tool("search").len(), 1);
        assert!(memory.actions_by_tool("missing").is_empty());
        assert_eq!(
            memory.last_action().map(|a| a.tool_name.as_str()),
            Some("fetch")
        );
        assert_eq!(memory.total_cost(), Decimal::ZERO);
        assert_eq!(memory.total_duration(), Duration::from_secs(1));
    }
}