use std::time::Duration;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::llm::ChatMessage;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionRecord {
pub id: Uuid,
pub sequence: u32,
pub tool_name: String,
pub input: serde_json::Value,
pub output_raw: Option<String>,
pub output_sanitized: Option<serde_json::Value>,
pub sanitization_warnings: Vec<String>,
pub cost: Option<Decimal>,
pub duration: Duration,
pub success: bool,
pub error: Option<String>,
pub executed_at: DateTime<Utc>,
}
impl ActionRecord {
pub fn new(sequence: u32, tool_name: impl Into<String>, input: serde_json::Value) -> Self {
Self {
id: Uuid::new_v4(),
sequence,
tool_name: tool_name.into(),
input,
output_raw: None,
output_sanitized: None,
sanitization_warnings: Vec::new(),
cost: None,
duration: Duration::ZERO,
success: false,
error: None,
executed_at: Utc::now(),
}
}
pub fn succeed(
mut self,
output_sanitized: Option<String>,
output_raw: serde_json::Value,
duration: Duration,
) -> Self {
self.success = true;
self.output_raw = Some(serde_json::to_string_pretty(&output_raw).unwrap_or_default());
self.output_sanitized = output_sanitized.map(serde_json::Value::String);
self.duration = duration;
self
}
pub fn fail(mut self, error: impl Into<String>, duration: Duration) -> Self {
self.success = false;
self.error = Some(error.into());
self.duration = duration;
self
}
pub fn with_warnings(mut self, warnings: Vec<String>) -> Self {
self.sanitization_warnings = warnings;
self
}
pub fn with_cost(mut self, cost: Decimal) -> Self {
self.cost = Some(cost);
self
}
}
#[derive(Debug, Clone, Default)]
pub struct ConversationMemory {
messages: Vec<ChatMessage>,
max_messages: usize,
}
impl ConversationMemory {
pub fn new(max_messages: usize) -> Self {
Self {
messages: Vec::new(),
max_messages,
}
}
pub fn add(&mut self, message: ChatMessage) {
self.messages.push(message);
while self.messages.len() > self.max_messages {
if self.messages.first().map(|m| m.role) == Some(crate::llm::Role::System) {
if self.messages.len() > 1 {
self.messages.remove(1);
} else {
break;
}
} else {
self.messages.remove(0);
}
}
}
pub fn messages(&self) -> &[ChatMessage] {
&self.messages
}
pub fn last_n(&self, n: usize) -> &[ChatMessage] {
let start = self.messages.len().saturating_sub(n);
&self.messages[start..]
}
pub fn clear(&mut self) {
self.messages.clear();
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
}
#[derive(Debug, Clone)]
pub struct Memory {
pub job_id: Uuid,
pub conversation: ConversationMemory,
pub actions: Vec<ActionRecord>,
next_sequence: u32,
}
impl Memory {
pub fn new(job_id: Uuid) -> Self {
Self {
job_id,
conversation: ConversationMemory::new(100),
actions: Vec::new(),
next_sequence: 0,
}
}
pub fn add_message(&mut self, message: ChatMessage) {
self.conversation.add(message);
}
pub fn create_action(
&mut self,
tool_name: impl Into<String>,
input: serde_json::Value,
) -> ActionRecord {
let seq = self.next_sequence;
self.next_sequence += 1;
ActionRecord::new(seq, tool_name, input)
}
pub fn record_action(&mut self, action: ActionRecord) {
self.actions.push(action);
}
pub fn total_cost(&self) -> Decimal {
self.actions
.iter()
.filter_map(|a| a.cost)
.fold(Decimal::ZERO, |acc, c| acc + c)
}
pub fn total_duration(&self) -> Duration {
self.actions
.iter()
.map(|a| a.duration)
.fold(Duration::ZERO, |acc, d| acc + d)
}
pub fn successful_actions(&self) -> usize {
self.actions.iter().filter(|a| a.success).count()
}
pub fn failed_actions(&self) -> usize {
self.actions.iter().filter(|a| !a.success).count()
}
pub fn last_action(&self) -> Option<&ActionRecord> {
self.actions.last()
}
pub fn actions_by_tool(&self, tool_name: &str) -> Vec<&ActionRecord> {
self.actions
.iter()
.filter(|a| a.tool_name == tool_name)
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_action_record() {
let action = ActionRecord::new(0, "test", serde_json::json!({"key": "value"}));
assert_eq!(action.sequence, 0); assert!(!action.success);
let action = action.succeed(
Some("raw".to_string()),
serde_json::json!({"result": "ok"}),
Duration::from_millis(100),
);
assert!(action.success); }
#[test]
fn test_conversation_memory() {
let mut memory = ConversationMemory::new(3);
memory.add(ChatMessage::user("Hello"));
memory.add(ChatMessage::assistant("Hi"));
memory.add(ChatMessage::user("How are you?"));
memory.add(ChatMessage::assistant("Good!"));
assert_eq!(memory.len(), 3); }
#[test]
fn test_memory_totals() {
let mut memory = Memory::new(Uuid::new_v4());
let action1 = memory
.create_action("tool1", serde_json::json!({}))
.succeed(None, serde_json::json!({}), Duration::from_secs(1))
.with_cost(Decimal::new(10, 1));
memory.record_action(action1);
let action2 = memory
.create_action("tool2", serde_json::json!({}))
.succeed(None, serde_json::json!({}), Duration::from_secs(2))
.with_cost(Decimal::new(20, 1));
memory.record_action(action2);
assert_eq!(memory.total_cost(), Decimal::new(30, 1)); assert_eq!(memory.total_duration(), Duration::from_secs(3)); assert_eq!(memory.successful_actions(), 2); }
#[test]
fn test_action_record_fail() {
let action = ActionRecord::new(1, "broken_tool", serde_json::json!({"x": 1}));
let action = action.fail("something went wrong", Duration::from_millis(50));
assert!(!action.success); assert_eq!(action.error.as_deref(), Some("something went wrong")); assert_eq!(action.duration, Duration::from_millis(50)); assert!(action.output_raw.is_none()); assert!(action.output_sanitized.is_none()); }
#[test]
fn test_action_record_with_warnings() {
let action = ActionRecord::new(0, "risky_tool", serde_json::json!({}));
let action = action.with_warnings(vec!["suspicious pattern".into(), "possible xss".into()]);
assert_eq!(action.sanitization_warnings.len(), 2); assert_eq!(action.sanitization_warnings[0], "suspicious pattern"); assert_eq!(action.sanitization_warnings[1], "possible xss"); }
#[test]
fn test_action_record_with_cost() {
let action = ActionRecord::new(0, "expensive_tool", serde_json::json!({}));
let cost = Decimal::new(42, 2); let action = action.with_cost(cost);
assert_eq!(action.cost, Some(Decimal::new(42, 2))); }
#[test]
fn test_action_record_new_defaults() {
let action = ActionRecord::new(5, "my_tool", serde_json::json!({"key": "val"}));
assert_eq!(action.sequence, 5); assert_eq!(action.tool_name, "my_tool"); assert_eq!(action.input, serde_json::json!({"key": "val"})); assert!(!action.success); assert!(action.output_raw.is_none()); assert!(action.output_sanitized.is_none()); assert!(action.sanitization_warnings.is_empty()); assert!(action.cost.is_none()); assert_eq!(action.duration, Duration::ZERO); assert!(action.error.is_none()); }
#[test]
fn test_action_record_succeed_sets_fields() {
let action = ActionRecord::new(0, "tool", serde_json::json!({}));
let action = action.succeed(
Some("sanitized output".into()),
serde_json::json!({"clean": true}),
Duration::from_secs(7),
);
assert!(action.success); let expected_raw =
serde_json::to_string_pretty(&serde_json::json!({"clean": true})).unwrap(); assert_eq!(action.output_raw.as_deref(), Some(expected_raw.as_str())); assert_eq!(
action.output_sanitized,
Some(serde_json::json!("sanitized output"))
);
assert_eq!(action.duration, Duration::from_secs(7)); }
#[test]
fn test_conversation_memory_clear() {
let mut mem = ConversationMemory::new(10);
mem.add(ChatMessage::user("hello"));
mem.add(ChatMessage::assistant("hi"));
assert_eq!(mem.len(), 2); assert!(!mem.is_empty());
mem.clear();
assert_eq!(mem.len(), 0); assert!(mem.is_empty()); assert!(mem.messages().is_empty()); }
#[test]
fn test_conversation_memory_last_n() {
let mut mem = ConversationMemory::new(10);
mem.add(ChatMessage::user("one"));
mem.add(ChatMessage::assistant("two"));
mem.add(ChatMessage::user("three"));
mem.add(ChatMessage::assistant("four"));
let last_2 = mem.last_n(2);
assert_eq!(last_2.len(), 2); assert_eq!(last_2[0].content, "three"); assert_eq!(last_2[1].content, "four");
let last_100 = mem.last_n(100);
assert_eq!(last_100.len(), 4); }
#[test]
fn test_conversation_memory_last_n_empty() {
let mem = ConversationMemory::new(10);
let result = mem.last_n(5);
assert!(result.is_empty()); }
#[test]
fn test_conversation_memory_preserves_system_message_on_trim() {
let mut mem = ConversationMemory::new(3);
mem.add(ChatMessage::system("You are helpful"));
mem.add(ChatMessage::user("msg1"));
mem.add(ChatMessage::user("msg2"));
mem.add(ChatMessage::user("msg3"));
assert_eq!(mem.len(), 3); assert_eq!(mem.messages()[0].role, crate::llm::Role::System); assert_eq!(mem.messages()[0].content, "You are helpful"); assert_eq!(mem.messages()[1].content, "msg2"); assert_eq!(mem.messages()[2].content, "msg3"); }
#[test]
fn test_conversation_memory_trims_non_system_first() {
let mut mem = ConversationMemory::new(2);
mem.add(ChatMessage::system("sys"));
mem.add(ChatMessage::user("a"));
mem.add(ChatMessage::user("b"));
assert_eq!(mem.len(), 2); assert_eq!(mem.messages()[0].role, crate::llm::Role::System); assert_eq!(mem.messages()[1].content, "b"); }
#[test]
fn test_conversation_memory_max_one_with_system_does_not_loop() {
let mut mem = ConversationMemory::new(1);
mem.add(ChatMessage::system("sys"));
mem.add(ChatMessage::user("hello"));
assert!(mem.len() <= 2); }
#[test]
fn test_memory_failed_actions() {
let mut memory = Memory::new(Uuid::new_v4());
let ok = memory.create_action("good", serde_json::json!({})).succeed(
None,
serde_json::json!({}),
Duration::from_millis(1),
);
memory.record_action(ok);
let err = memory
.create_action("bad", serde_json::json!({}))
.fail("oops", Duration::from_millis(2));
memory.record_action(err);
assert_eq!(memory.successful_actions(), 1); assert_eq!(memory.failed_actions(), 1); }
#[test]
fn test_memory_last_action() {
let mut memory = Memory::new(Uuid::new_v4());
assert!(memory.last_action().is_none());
let a1 = memory
.create_action("first", serde_json::json!({}))
.succeed(None, serde_json::json!({}), Duration::ZERO);
memory.record_action(a1);
let a2 = memory
.create_action("second", serde_json::json!({}))
.fail("nope", Duration::ZERO);
memory.record_action(a2);
let last = memory.last_action().unwrap(); assert_eq!(last.tool_name, "second"); }
#[test]
fn test_memory_actions_by_tool() {
let mut memory = Memory::new(Uuid::new_v4());
for _ in 0..3 {
let a = memory
.create_action("shell", serde_json::json!({}))
.succeed(None, serde_json::json!({}), Duration::ZERO);
memory.record_action(a);
}
let a = memory.create_action("http", serde_json::json!({})).succeed(
None,
serde_json::json!({}),
Duration::ZERO,
);
memory.record_action(a);
assert_eq!(memory.actions_by_tool("shell").len(), 3); assert_eq!(memory.actions_by_tool("http").len(), 1); assert_eq!(memory.actions_by_tool("nonexistent").len(), 0); }
#[test]
fn test_memory_create_action_increments_sequence() {
let mut memory = Memory::new(Uuid::new_v4());
let a0 = memory.create_action("t", serde_json::json!({}));
assert_eq!(a0.sequence, 0);
let a1 = memory.create_action("t", serde_json::json!({}));
assert_eq!(a1.sequence, 1);
let a2 = memory.create_action("t", serde_json::json!({}));
assert_eq!(a2.sequence, 2); }
#[test]
fn test_memory_add_message_delegates_to_conversation() {
let mut memory = Memory::new(Uuid::new_v4());
assert!(memory.conversation.is_empty());
memory.add_message(ChatMessage::user("hello"));
memory.add_message(ChatMessage::assistant("hi"));
assert_eq!(memory.conversation.len(), 2); assert_eq!(memory.conversation.messages()[0].content, "hello"); }
#[test]
fn test_memory_total_cost_with_no_cost_actions() {
let mut memory = Memory::new(Uuid::new_v4());
let a = memory
.create_action("free_tool", serde_json::json!({}))
.succeed(None, serde_json::json!({}), Duration::ZERO);
memory.record_action(a);
assert_eq!(memory.total_cost(), Decimal::ZERO); }
#[test]
fn test_memory_total_duration_mixed() {
let mut memory = Memory::new(Uuid::new_v4());
let a1 = memory.create_action("t1", serde_json::json!({})).succeed(
None,
serde_json::json!({}),
Duration::from_millis(100),
);
memory.record_action(a1);
let a2 = memory
.create_action("t2", serde_json::json!({}))
.fail("err", Duration::from_millis(200));
memory.record_action(a2);
assert_eq!(memory.total_duration(), Duration::from_millis(300)); }
}