// hehe_agent/response.rs

1use hehe_core::{Id, Metadata};
2use hehe_tools::ToolOutput;
3use serde::{Deserialize, Serialize};
4
5#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct ToolCallRecord {
7    pub id: String,
8    pub name: String,
9    pub input: serde_json::Value,
10    pub output: String,
11    pub is_error: bool,
12    pub duration_ms: u64,
13}
14
15#[derive(Clone, Debug, Serialize, Deserialize)]
16pub struct AgentResponse {
17    pub session_id: Id,
18    pub text: String,
19    pub tool_calls: Vec<ToolCallRecord>,
20    pub iterations: usize,
21    pub metadata: Metadata,
22}
23
24impl AgentResponse {
25    pub fn new(session_id: Id, text: impl Into<String>) -> Self {
26        Self {
27            session_id,
28            text: text.into(),
29            tool_calls: Vec::new(),
30            iterations: 1,
31            metadata: Metadata::new(),
32        }
33    }
34
35    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCallRecord>) -> Self {
36        self.tool_calls = tool_calls;
37        self
38    }
39
40    pub fn with_iterations(mut self, iterations: usize) -> Self {
41        self.iterations = iterations;
42        self
43    }
44
45    pub fn with_metadata<K: Into<String>, V: Serialize>(mut self, key: K, value: V) -> Self {
46        self.metadata.insert(key, value);
47        self
48    }
49
50    pub fn text(&self) -> &str {
51        &self.text
52    }
53
54    pub fn has_tool_calls(&self) -> bool {
55        !self.tool_calls.is_empty()
56    }
57
58    pub fn tool_call_count(&self) -> usize {
59        self.tool_calls.len()
60    }
61
62    pub fn successful_tool_calls(&self) -> impl Iterator<Item = &ToolCallRecord> {
63        self.tool_calls.iter().filter(|tc| !tc.is_error)
64    }
65
66    pub fn failed_tool_calls(&self) -> impl Iterator<Item = &ToolCallRecord> {
67        self.tool_calls.iter().filter(|tc| tc.is_error)
68    }
69}
70
71impl ToolCallRecord {
72    pub fn success(
73        id: impl Into<String>,
74        name: impl Into<String>,
75        input: serde_json::Value,
76        output: &ToolOutput,
77        duration_ms: u64,
78    ) -> Self {
79        Self {
80            id: id.into(),
81            name: name.into(),
82            input,
83            output: output.content.clone(),
84            is_error: output.is_error,
85            duration_ms,
86        }
87    }
88
89    pub fn error(
90        id: impl Into<String>,
91        name: impl Into<String>,
92        input: serde_json::Value,
93        error_msg: impl Into<String>,
94        duration_ms: u64,
95    ) -> Self {
96        Self {
97            id: id.into(),
98            name: name.into(),
99            input,
100            output: error_msg.into(),
101            is_error: true,
102            duration_ms,
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-error record with the given id for use as a fixture.
    fn ok_record(id: &str) -> ToolCallRecord {
        ToolCallRecord {
            id: id.to_string(),
            name: "read_file".to_string(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
            output: "file content".to_string(),
            is_error: false,
            duration_ms: 100,
        }
    }

    #[test]
    fn test_response_basic() {
        let response = AgentResponse::new(Id::new(), "Hello, world!");
        assert_eq!(response.text(), "Hello, world!");
        assert!(!response.has_tool_calls());
        // Pin the defaults established by `new`.
        assert_eq!(response.tool_call_count(), 0);
        assert_eq!(response.iterations, 1);
    }

    #[test]
    fn test_response_with_tool_calls() {
        let response = AgentResponse::new(Id::new(), "Done!")
            .with_tool_calls(vec![ok_record("call_1")])
            .with_iterations(2);

        assert!(response.has_tool_calls());
        assert_eq!(response.tool_call_count(), 1);
        assert_eq!(response.iterations, 2);
    }

    #[test]
    fn test_tool_call_record() {
        let record = ToolCallRecord::error(
            "call_1",
            "write_file",
            serde_json::json!({}),
            "Permission denied",
            50,
        );

        assert!(record.is_error);
        assert_eq!(record.output, "Permission denied");
        assert_eq!(record.id, "call_1");
        assert_eq!(record.name, "write_file");
        assert_eq!(record.duration_ms, 50);
    }

    #[test]
    fn test_successful_and_failed_tool_calls() {
        let failed = ToolCallRecord::error(
            "call_2",
            "write_file",
            serde_json::json!({}),
            "Permission denied",
            50,
        );

        let response = AgentResponse::new(Id::new(), "Done!")
            .with_tool_calls(vec![ok_record("call_1"), failed]);

        let ok_ids: Vec<&str> = response
            .successful_tool_calls()
            .map(|tc| tc.id.as_str())
            .collect();
        let failed_ids: Vec<&str> = response
            .failed_tool_calls()
            .map(|tc| tc.id.as_str())
            .collect();

        // Every record lands in exactly one of the two views.
        assert_eq!(ok_ids, ["call_1"]);
        assert_eq!(failed_ids, ["call_2"]);
        assert_eq!(response.tool_call_count(), 2);
    }
}