//! hefa_core/agent/mod.rs — agent orchestration: wires an LLM client,
//! a tool registry, and an optional tracer into a single invokable agent.
1use std::sync::Arc;
2
3use serde_json::Value;
4use thiserror::Error;
5
6use crate::config::{ConfigError, ProviderKind};
7use crate::llm::{
8    ChatMessage, LLMClient, LlmClient, LlmError, LlmRequest, LlmResponse, MessageRole,
9    StructuredOutput, ToolCall,
10};
11use crate::tools::{Tool, ToolError, ToolRegistry, ToolResult};
12use crate::tracing::{SpanContext, SpanStatus, TraceError, Tracer};
13
/// Everything needed to construct an [`Agent`].
pub struct AgentConfig {
    /// System instruction sent as the first (system-role) message of every turn.
    pub instruction: String,
    /// Provider the LLM client is built for (used by `Agent::new`; ignored by
    /// `Agent::with_llm`, which receives a ready-made client).
    pub provider: ProviderKind,
    /// Model identifier forwarded to the provider (e.g. "gpt-4o").
    pub model: String,
    /// Optional structured-output schema attached to every LLM request.
    pub structured_output: Option<StructuredOutput>,
    /// Tools registered with the agent's registry at construction time.
    pub tools: Vec<Box<dyn Tool>>,
}
21
/// A tool call requested by the LLM paired with the result of executing it.
#[derive(Debug, Clone)]
pub struct ToolCallOutcome {
    /// The invocation as requested by the model (id, tool name, arguments).
    pub call: ToolCall,
    /// The output produced by running the named tool with those arguments.
    pub result: ToolResult,
}
27
/// Outcome of a single [`Agent::invoke`] turn.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// The final LLM response (the second response when tools were called).
    pub response: LlmResponse,
    /// Every tool executed during the turn, in the order the calls appeared.
    pub tool_outputs: Vec<ToolCallOutcome>,
}
33
/// All failure modes surfaced by the agent.
#[derive(Debug, Error)]
pub enum AgentError {
    /// Building the LLM client from environment configuration failed.
    #[error("configuration error: {0}")]
    Config(#[from] ConfigError),
    /// The underlying LLM call failed.
    #[error("llm error: {0}")]
    Llm(#[from] LlmError),
    /// A registered tool failed while executing a requested call.
    #[error("tool error: {0}")]
    Tool(#[from] ToolError),
    /// Serializing a tool result into a chat message failed.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// Starting, annotating, or ending a trace span failed.
    #[error("trace error: {0}")]
    Trace(#[from] TraceError),
}
47
/// An LLM-backed agent: forwards prompts to a client, executes any tool calls
/// the model requests, and optionally records the turn with a tracer.
pub struct Agent {
    // Shared LLM client; Arc so it can be supplied externally and cloned cheaply.
    llm: Arc<dyn LlmClient>,
    // Model identifier, recorded in trace attributes and events.
    model: String,
    // System instruction prepended to every conversation.
    instruction: String,
    // Optional structured-output schema attached to every LLM request.
    structured_output: Option<StructuredOutput>,
    // Registry of tools the agent can execute on the model's behalf.
    registry: ToolRegistry,
    // Optional tracer; when None, invoke runs without any span bookkeeping.
    tracer: Option<Arc<dyn Tracer>>,
}
56
57impl Agent {
58    pub fn new(config: AgentConfig) -> Result<Self, AgentError> {
59        let AgentConfig {
60            instruction,
61            provider,
62            model,
63            structured_output,
64            mut tools,
65        } = config;
66        let llm = LLMClient::from_env(provider, &model)?;
67        let mut agent = Self {
68            llm: Arc::new(llm),
69            model,
70            instruction,
71            structured_output,
72            registry: ToolRegistry::new(),
73            tracer: None,
74        };
75        for tool in tools.drain(..) {
76            agent.register_tool_boxed(tool);
77        }
78        Ok(agent)
79    }
80
81    pub fn with_llm(llm: Arc<dyn LlmClient>, config: AgentConfig) -> Self {
82        let AgentConfig {
83            instruction,
84            provider: _,
85            model,
86            structured_output,
87            mut tools,
88        } = config;
89        let mut agent = Self {
90            llm,
91            model,
92            instruction,
93            structured_output,
94            registry: ToolRegistry::new(),
95            tracer: None,
96        };
97        for tool in tools.drain(..) {
98            agent.register_tool_boxed(tool);
99        }
100        agent
101    }
102
103    pub fn with_tracer(mut self, tracer: Arc<dyn Tracer>) -> Self {
104        self.tracer = Some(tracer);
105        self
106    }
107
108    pub fn set_tracer(&mut self, tracer: Arc<dyn Tracer>) {
109        self.tracer = Some(tracer);
110    }
111
112    pub fn register_tool<T>(&mut self, tool: T)
113    where
114        T: crate::tools::Tool + 'static,
115    {
116        self.registry.register(tool);
117    }
118
119    pub fn register_tool_boxed(&mut self, tool: Box<dyn Tool>) {
120        self.registry.register_boxed(tool);
121    }
122
123    pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
124        &mut self.registry
125    }
126
127    pub async fn invoke(&self, user_prompt: &str) -> Result<AgentResult, AgentError> {
128        let tracer = self.tracer.clone();
129        let mut span_ctx: Option<SpanContext> = None;
130
131        if let Some(tracer_ref) = tracer.as_ref() {
132            let ctx = tracer_ref
133                .start_trace(
134                    "agent.invoke",
135                    serde_json::json!({
136                        "model": self.model,
137                        "instruction": self.instruction,
138                        "user_prompt": user_prompt,
139                    }),
140                )
141                .await?;
142            span_ctx = Some(ctx);
143        }
144
145        let mut result = self
146            .invoke_inner(user_prompt, tracer.clone(), span_ctx.as_ref())
147            .await;
148
149        if let Some(tracer_ref) = tracer {
150            if let Some(ctx) = span_ctx {
151                let status = match &result {
152                    Ok(_) => SpanStatus::Ok,
153                    Err(err) => SpanStatus::Error(err.to_string()),
154                };
155                if let Err(trace_err) = tracer_ref.end_span(ctx, status).await {
156                    result = Err(AgentError::Trace(trace_err));
157                }
158            }
159        }
160
161        result
162    }
163
164    async fn invoke_inner(
165        &self,
166        user_prompt: &str,
167        tracer: Option<Arc<dyn Tracer>>,
168        span_ctx: Option<&SpanContext>,
169    ) -> Result<AgentResult, AgentError> {
170        let mut messages = vec![
171            ChatMessage {
172                role: MessageRole::System,
173                content: self.instruction.clone(),
174                tool_call_id: None,
175            },
176            ChatMessage {
177                role: MessageRole::User,
178                content: user_prompt.to_string(),
179                tool_call_id: None,
180            },
181        ];
182        if let (Some(tracer_ref), Some(span)) = (&tracer, span_ctx) {
183            tracer_ref
184                .record_event(
185                    &span.id,
186                    "llm.request",
187                    serde_json::json!({
188                        "model": self.model,
189                        "instruction": self.instruction,
190                        "messages": messages.len()
191                    }),
192                )
193                .await?;
194        }
195
196        let mut response = self
197            .llm
198            .invoke(LlmRequest {
199                messages: messages.clone(),
200                structured_output: self.structured_output.clone(),
201                tools: vec![],
202                tool_choice: None,
203            })
204            .await?;
205
206        let mut tool_outputs = Vec::new();
207
208        if !response.tool_calls.is_empty() {
209            for call in &response.tool_calls {
210                let tool_result = self
211                    .registry
212                    .invoke(&call.name, call.arguments.clone())
213                    .await?;
214                let bridge = ToolBridge {
215                    tool_call_id: call.id.clone(),
216                    output: tool_result.content.clone(),
217                };
218                let output = serde_json::to_string(&bridge)
219                    .map_err(|e| AgentError::Serialization(e.to_string()))?;
220                let tool_message = ChatMessage {
221                    role: MessageRole::Tool,
222                    content: output,
223                    tool_call_id: Some(call.id.clone()),
224                };
225                messages.push(tool_message);
226                tool_outputs.push(ToolCallOutcome {
227                    call: call.clone(),
228                    result: tool_result,
229                });
230
231                if let (Some(tracer_ref), Some(span)) = (&tracer, span_ctx) {
232                    tracer_ref
233                        .record_event(
234                            &span.id,
235                            "tool.call",
236                            serde_json::json!({
237                                "tool": call.name,
238                                "tool_call_id": call.id,
239                                "model": self.model,
240                            }),
241                        )
242                        .await?;
243                }
244            }
245
246            if let (Some(tracer_ref), Some(span)) = (&tracer, span_ctx) {
247                tracer_ref
248                    .record_event(
249                        &span.id,
250                        "llm.reenter",
251                        serde_json::json!({ "messages": messages.len() }),
252                    )
253                    .await?;
254            }
255
256            response = self
257                .llm
258                .invoke(LlmRequest {
259                    messages,
260                    structured_output: self.structured_output.clone(),
261                    tools: vec![],
262                    tool_choice: None,
263                })
264                .await?;
265        }
266
267        Ok(AgentResult {
268            response,
269            tool_outputs,
270        })
271    }
272}
273
/// Serialized payload embedded in a tool-role chat message, tying a tool's
/// output back to the id of the call that produced it.
#[derive(Debug, Clone, serde::Serialize)]
struct ToolBridge {
    tool_call_id: String,
    output: Value,
}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmResponse, ToolCall};
    use crate::tools::{Tool, ToolResult};
    use crate::tracing::{SpanContext, Tracer};
    use async_trait::async_trait;
    use serde_json::json;
    use tokio::sync::Mutex;

    /// Scripted LLM double: returns canned responses FIFO and records every
    /// request it receives for later inspection.
    struct MockLlm {
        // Responses popped front-first by `invoke`.
        responses: Mutex<std::collections::VecDeque<LlmResponse>>,
        // Every request passed to `invoke`, in call order.
        captured: Mutex<Vec<LlmRequest>>,
    }

    impl MockLlm {
        fn new(responses: Vec<LlmResponse>) -> Self {
            Self {
                responses: Mutex::new(responses.into()),
                captured: Mutex::new(vec![]),
            }
        }

        // Snapshot of all requests seen so far.
        async fn captured(&self) -> Vec<LlmRequest> {
            self.captured.lock().await.clone()
        }
    }

    #[async_trait]
    impl LlmClient for MockLlm {
        async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.captured.lock().await.push(request);
            // Running out of scripted responses is a test bug; surface it as
            // an LLM error rather than panicking.
            self.responses
                .lock()
                .await
                .pop_front()
                .ok_or_else(|| LlmError::Unsupported("no mock response".into()))
        }
    }

    /// A plain turn (no tool calls) sends exactly one request whose first two
    /// messages are the system instruction and the user prompt.
    #[tokio::test]
    async fn invokes_llm_with_instruction_and_prompt() {
        let llm = Arc::new(MockLlm::new(vec![LlmResponse {
            content: "hello".into(),
            tool_calls: vec![],
        }]));
        let config = AgentConfig {
            instruction: "You are helpful".into(),
            provider: ProviderKind::OpenAi,
            model: "gpt-4o".into(),
            structured_output: None,
            tools: Vec::new(),
        };
        let agent = Agent::with_llm(llm.clone(), config);
        let result = agent.invoke("Hi").await.expect("agent result");
        assert_eq!(result.response.content, "hello");
        let captured = llm.captured().await;
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].messages[0].role, MessageRole::System);
        assert_eq!(captured[0].messages[1].role, MessageRole::User);
    }

    /// Trivial tool that echoes its JSON arguments back as its result.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn json_schema(&self) -> Value {
            json!({})
        }

        async fn call(&self, args: Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult { content: args })
        }
    }

    /// When the first response requests a tool call, the agent executes the
    /// tool, appends a tool-role message, and re-enters the LLM once; the
    /// second response is the final answer.
    #[tokio::test]
    async fn executes_tool_call_and_reenters_llm() {
        let llm = Arc::new(MockLlm::new(vec![
            LlmResponse {
                content: "CALL_TOOL".into(),
                tool_calls: vec![ToolCall {
                    id: "call-1".into(),
                    name: "echo".into(),
                    arguments: json!({"message": "hello"}),
                }],
            },
            LlmResponse {
                content: "final answer".into(),
                tool_calls: vec![],
            },
        ]));
        let config = AgentConfig {
            instruction: "helper".into(),
            provider: ProviderKind::OpenAi,
            model: "gpt-4o".into(),
            structured_output: None,
            tools: Vec::new(),
        };
        let mut agent = Agent::with_llm(llm.clone(), config);
        agent.register_tool(EchoTool);
        let result = agent.invoke("Hi").await.expect("agent result");
        assert_eq!(result.response.content, "final answer");
        assert_eq!(result.tool_outputs.len(), 1);
        assert_eq!(result.tool_outputs[0].call.name, "echo");
        assert_eq!(
            result.tool_outputs[0].result.content,
            json!({"message": "hello"})
        );
        let captured = llm.captured().await;
        assert_eq!(captured.len(), 2);
        // The re-entry request must end with the tool output message.
        assert_eq!(captured[1].messages.last().unwrap().role, MessageRole::Tool);
    }

    /// Tracer double that records every span/event as a formatted string.
    #[derive(Default)]
    struct RecordingTracer {
        events: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl Tracer for RecordingTracer {
        async fn start_span(
            &self,
            name: &str,
            attributes: Value,
            parent: Option<String>,
        ) -> Result<SpanContext, TraceError> {
            self.events
                .lock()
                .await
                .push(format!("start:{name}:{attributes}"));
            Ok(SpanContext::new(name, parent, attributes))
        }

        async fn end_span(&self, span: SpanContext, status: SpanStatus) -> Result<(), TraceError> {
            self.events
                .lock()
                .await
                .push(format!("end:{}:{:?}", span.name, status));
            Ok(())
        }

        async fn record_event(
            &self,
            span_id: &str,
            name: &str,
            attributes: Value,
        ) -> Result<(), TraceError> {
            self.events
                .lock()
                .await
                .push(format!("event:{span_id}:{name}:{attributes}",))
            Ok(())
        }
    }

    /// A traced invocation produces both a start and an end event for the
    /// `agent.invoke` span.
    #[tokio::test]
    async fn agents_can_record_traces() {
        let llm = Arc::new(MockLlm::new(vec![LlmResponse {
            content: "ok".into(),
            tool_calls: vec![],
        }]));
        let config = AgentConfig {
            instruction: "helper".into(),
            provider: ProviderKind::OpenAi,
            model: "gpt-4o".into(),
            structured_output: None,
            tools: Vec::new(),
        };
        let tracer = Arc::new(RecordingTracer::default());
        let agent = Agent::with_llm(llm.clone(), config).with_tracer(tracer.clone());
        agent.invoke("Hi").await.expect("agent");
        let events = tracer.events.lock().await;
        assert!(events.iter().any(|e| e.starts_with("start:agent.invoke")));
        assert!(events.iter().any(|e| e.starts_with("end:agent.invoke")));
    }
}
459}