1use std::sync::Arc;
2
3use serde_json::Value;
4use thiserror::Error;
5
6use crate::config::{ConfigError, ProviderKind};
7use crate::llm::{
8 ChatMessage, LLMClient, LlmClient, LlmError, LlmRequest, LlmResponse, MessageRole,
9 StructuredOutput, ToolCall,
10};
11use crate::tools::{Tool, ToolError, ToolRegistry, ToolResult};
12use crate::tracing::{SpanContext, SpanStatus, TraceError, Tracer};
13
/// Construction parameters for an [`Agent`].
///
/// Consumed by [`Agent::new`], which builds an LLM client from `provider` and `model`, and by
/// [`Agent::with_llm`], which ignores `provider` and uses the client it is given.
pub struct AgentConfig {
    /// System instruction; sent as the first (system) message of every LLM request.
    pub instruction: String,
    /// Provider used by `Agent::new` to build the client; ignored by `Agent::with_llm`.
    pub provider: ProviderKind,
    /// Model name; passed to the client constructor in `Agent::new` and included in trace
    /// attributes.
    pub model: String,
    /// Optional structured-output spec, forwarded unchanged on every LLM request.
    pub structured_output: Option<StructuredOutput>,
    /// Tools registered on the agent at construction time, in order.
    pub tools: Vec<Box<dyn Tool>>,
}
21
/// A tool call requested by the model, paired with the result the tool registry returned.
#[derive(Debug, Clone)]
pub struct ToolCallOutcome {
    /// The call exactly as it appeared in the model's response.
    pub call: ToolCall,
    /// Output of invoking the named tool with the call's arguments.
    pub result: ToolResult,
}
27
/// Outcome of a successful [`Agent::invoke`].
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// The last LLM response: the follow-up response when tools were executed, otherwise
    /// the first one.
    pub response: LlmResponse,
    /// Every tool call executed during the invocation, in execution order.
    pub tool_outputs: Vec<ToolCallOutcome>,
}
33
/// Errors produced while constructing or invoking an [`Agent`].
#[derive(Debug, Error)]
pub enum AgentError {
    /// Configuration failure (converted from [`ConfigError`]).
    #[error("configuration error: {0}")]
    Config(#[from] ConfigError),
    /// Failure reported by the LLM client (converted from [`LlmError`]).
    #[error("llm error: {0}")]
    Llm(#[from] LlmError),
    /// Failure from the tools layer (converted from [`ToolError`]).
    #[error("tool error: {0}")]
    Tool(#[from] ToolError),
    /// A tool result could not be serialized into a tool message; holds the serializer's
    /// error text.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// Failure reported by the tracer (converted from [`TraceError`]).
    #[error("trace error: {0}")]
    Trace(#[from] TraceError),
}
47
/// An LLM-backed agent: a fixed system instruction, a model client, a tool registry and an
/// optional tracer.
///
/// Build with [`Agent::new`] or [`Agent::with_llm`], then call [`Agent::invoke`].
pub struct Agent {
    /// Client used for every model request.
    llm: Arc<dyn LlmClient>,
    /// Model name; included in trace attributes.
    model: String,
    /// System instruction sent as the first message of each request.
    instruction: String,
    /// Structured-output spec forwarded on every request.
    structured_output: Option<StructuredOutput>,
    /// Tools that the model's tool calls are dispatched to.
    registry: ToolRegistry,
    /// When set, `invoke` opens an `agent.invoke` span and records events on it.
    tracer: Option<Arc<dyn Tracer>>,
}
56
57impl Agent {
58 pub fn new(config: AgentConfig) -> Result<Self, AgentError> {
59 let AgentConfig {
60 instruction,
61 provider,
62 model,
63 structured_output,
64 mut tools,
65 } = config;
66 let llm = LLMClient::from_env(provider, &model)?;
67 let mut agent = Self {
68 llm: Arc::new(llm),
69 model,
70 instruction,
71 structured_output,
72 registry: ToolRegistry::new(),
73 tracer: None,
74 };
75 for tool in tools.drain(..) {
76 agent.register_tool_boxed(tool);
77 }
78 Ok(agent)
79 }
80
81 pub fn with_llm(llm: Arc<dyn LlmClient>, config: AgentConfig) -> Self {
82 let AgentConfig {
83 instruction,
84 provider: _,
85 model,
86 structured_output,
87 mut tools,
88 } = config;
89 let mut agent = Self {
90 llm,
91 model,
92 instruction,
93 structured_output,
94 registry: ToolRegistry::new(),
95 tracer: None,
96 };
97 for tool in tools.drain(..) {
98 agent.register_tool_boxed(tool);
99 }
100 agent
101 }
102
103 pub fn with_tracer(mut self, tracer: Arc<dyn Tracer>) -> Self {
104 self.tracer = Some(tracer);
105 self
106 }
107
108 pub fn set_tracer(&mut self, tracer: Arc<dyn Tracer>) {
109 self.tracer = Some(tracer);
110 }
111
112 pub fn register_tool<T>(&mut self, tool: T)
113 where
114 T: crate::tools::Tool + 'static,
115 {
116 self.registry.register(tool);
117 }
118
119 pub fn register_tool_boxed(&mut self, tool: Box<dyn Tool>) {
120 self.registry.register_boxed(tool);
121 }
122
123 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
124 &mut self.registry
125 }
126
127 pub async fn invoke(&self, user_prompt: &str) -> Result<AgentResult, AgentError> {
128 let tracer = self.tracer.clone();
129 let mut span_ctx: Option<SpanContext> = None;
130
131 if let Some(tracer_ref) = tracer.as_ref() {
132 let ctx = tracer_ref
133 .start_trace(
134 "agent.invoke",
135 serde_json::json!({
136 "model": self.model,
137 "instruction": self.instruction,
138 "user_prompt": user_prompt,
139 }),
140 )
141 .await?;
142 span_ctx = Some(ctx);
143 }
144
145 let mut result = self
146 .invoke_inner(user_prompt, tracer.clone(), span_ctx.as_ref())
147 .await;
148
149 if let Some(tracer_ref) = tracer {
150 if let Some(ctx) = span_ctx {
151 let status = match &result {
152 Ok(_) => SpanStatus::Ok,
153 Err(err) => SpanStatus::Error(err.to_string()),
154 };
155 if let Err(trace_err) = tracer_ref.end_span(ctx, status).await {
156 result = Err(AgentError::Trace(trace_err));
157 }
158 }
159 }
160
161 result
162 }
163
164 async fn invoke_inner(
165 &self,
166 user_prompt: &str,
167 tracer: Option<Arc<dyn Tracer>>,
168 span_ctx: Option<&SpanContext>,
169 ) -> Result<AgentResult, AgentError> {
170 let mut messages = vec![
171 ChatMessage {
172 role: MessageRole::System,
173 content: self.instruction.clone(),
174 tool_call_id: None,
175 },
176 ChatMessage {
177 role: MessageRole::User,
178 content: user_prompt.to_string(),
179 tool_call_id: None,
180 },
181 ];
182 if let (Some(tracer_ref), Some(span)) = (&tracer, span_ctx) {
183 tracer_ref
184 .record_event(
185 &span.id,
186 "llm.request",
187 serde_json::json!({
188 "model": self.model,
189 "instruction": self.instruction,
190 "messages": messages.len()
191 }),
192 )
193 .await?;
194 }
195
196 let mut response = self
197 .llm
198 .invoke(LlmRequest {
199 messages: messages.clone(),
200 structured_output: self.structured_output.clone(),
201 tools: vec![],
202 tool_choice: None,
203 })
204 .await?;
205
206 let mut tool_outputs = Vec::new();
207
208 if !response.tool_calls.is_empty() {
209 for call in &response.tool_calls {
210 let tool_result = self
211 .registry
212 .invoke(&call.name, call.arguments.clone())
213 .await?;
214 let bridge = ToolBridge {
215 tool_call_id: call.id.clone(),
216 output: tool_result.content.clone(),
217 };
218 let output = serde_json::to_string(&bridge)
219 .map_err(|e| AgentError::Serialization(e.to_string()))?;
220 let tool_message = ChatMessage {
221 role: MessageRole::Tool,
222 content: output,
223 tool_call_id: Some(call.id.clone()),
224 };
225 messages.push(tool_message);
226 tool_outputs.push(ToolCallOutcome {
227 call: call.clone(),
228 result: tool_result,
229 });
230
231 if let (Some(tracer_ref), Some(span)) = (&tracer, span_ctx) {
232 tracer_ref
233 .record_event(
234 &span.id,
235 "tool.call",
236 serde_json::json!({
237 "tool": call.name,
238 "tool_call_id": call.id,
239 "model": self.model,
240 }),
241 )
242 .await?;
243 }
244 }
245
246 if let (Some(tracer_ref), Some(span)) = (&tracer, span_ctx) {
247 tracer_ref
248 .record_event(
249 &span.id,
250 "llm.reenter",
251 serde_json::json!({ "messages": messages.len() }),
252 )
253 .await?;
254 }
255
256 response = self
257 .llm
258 .invoke(LlmRequest {
259 messages,
260 structured_output: self.structured_output.clone(),
261 tools: vec![],
262 tool_choice: None,
263 })
264 .await?;
265 }
266
267 Ok(AgentResult {
268 response,
269 tool_outputs,
270 })
271 }
272}
273
/// JSON payload used as the content of a tool-result chat message.
///
/// Serialized with `serde_json::to_string`. The declaration order of the fields determines
/// the key order in the emitted JSON, so reordering them changes what the model sees.
#[derive(Debug, Clone, serde::Serialize)]
struct ToolBridge {
    /// Id of the tool call this output answers; the same id is set as the message's
    /// `tool_call_id`.
    tool_call_id: String,
    /// The tool's result content.
    output: Value,
}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmResponse, ToolCall};
    use crate::tools::{Tool, ToolResult};
    use crate::tracing::{SpanContext, Tracer};
    use async_trait::async_trait;
    use serde_json::json;
    use tokio::sync::Mutex;

    /// Config shared by every test: OpenAI `gpt-4o`, no structured output, no preset tools.
    fn test_config(instruction: &str) -> AgentConfig {
        AgentConfig {
            instruction: instruction.to_string(),
            provider: ProviderKind::OpenAi,
            model: "gpt-4o".into(),
            structured_output: None,
            tools: Vec::new(),
        }
    }

    /// A text-only model reply (no tool calls).
    fn text_reply(content: &'static str) -> LlmResponse {
        LlmResponse {
            content: content.into(),
            tool_calls: vec![],
        }
    }

    /// Scripted LLM: pops queued responses in order and remembers every request it receives.
    struct MockLlm {
        responses: Mutex<std::collections::VecDeque<LlmResponse>>,
        captured: Mutex<Vec<LlmRequest>>,
    }

    impl MockLlm {
        fn new(responses: Vec<LlmResponse>) -> Self {
            Self {
                responses: Mutex::new(std::collections::VecDeque::from(responses)),
                captured: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of every request seen so far, in arrival order.
        async fn captured(&self) -> Vec<LlmRequest> {
            self.captured.lock().await.iter().cloned().collect()
        }
    }

    #[async_trait]
    impl LlmClient for MockLlm {
        async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.captured.lock().await.push(request);
            let next = self.responses.lock().await.pop_front();
            next.ok_or_else(|| LlmError::Unsupported("no mock response".into()))
        }
    }

    #[tokio::test]
    async fn invokes_llm_with_instruction_and_prompt() {
        let llm = Arc::new(MockLlm::new(vec![text_reply("hello")]));
        let agent = Agent::with_llm(llm.clone(), test_config("You are helpful"));

        let outcome = agent.invoke("Hi").await.expect("agent result");
        assert_eq!(outcome.response.content, "hello");

        let requests = llm.captured().await;
        assert_eq!(requests.len(), 1);
        let first = &requests[0];
        assert_eq!(first.messages[0].role, MessageRole::System);
        assert_eq!(first.messages[1].role, MessageRole::User);
    }

    /// Tool that returns its arguments unchanged.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn json_schema(&self) -> Value {
            json!({})
        }

        async fn call(&self, args: Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult { content: args })
        }
    }

    #[tokio::test]
    async fn executes_tool_call_and_reenters_llm() {
        let tool_request = LlmResponse {
            content: "CALL_TOOL".into(),
            tool_calls: vec![ToolCall {
                id: "call-1".into(),
                name: "echo".into(),
                arguments: json!({"message": "hello"}),
            }],
        };
        let llm = Arc::new(MockLlm::new(vec![tool_request, text_reply("final answer")]));
        let mut agent = Agent::with_llm(llm.clone(), test_config("helper"));
        agent.register_tool(EchoTool);

        let outcome = agent.invoke("Hi").await.expect("agent result");
        assert_eq!(outcome.response.content, "final answer");
        assert_eq!(outcome.tool_outputs.len(), 1);
        let executed = &outcome.tool_outputs[0];
        assert_eq!(executed.call.name, "echo");
        assert_eq!(executed.result.content, json!({"message": "hello"}));

        let requests = llm.captured().await;
        assert_eq!(requests.len(), 2);
        let last_message = requests[1]
            .messages
            .last()
            .expect("re-entry request carries messages");
        assert_eq!(last_message.role, MessageRole::Tool);
    }

    /// Tracer that appends a formatted line for every callback it receives.
    #[derive(Default)]
    struct RecordingTracer {
        events: Mutex<Vec<String>>,
    }

    impl RecordingTracer {
        /// Appends one entry to the event log.
        async fn log(&self, entry: String) {
            self.events.lock().await.push(entry);
        }
    }

    #[async_trait]
    impl Tracer for RecordingTracer {
        async fn start_span(
            &self,
            name: &str,
            attributes: Value,
            parent: Option<String>,
        ) -> Result<SpanContext, TraceError> {
            self.log(format!("start:{name}:{attributes}")).await;
            Ok(SpanContext::new(name, parent, attributes))
        }

        async fn end_span(&self, span: SpanContext, status: SpanStatus) -> Result<(), TraceError> {
            self.log(format!("end:{}:{:?}", span.name, status)).await;
            Ok(())
        }

        async fn record_event(
            &self,
            span_id: &str,
            name: &str,
            attributes: Value,
        ) -> Result<(), TraceError> {
            self.log(format!("event:{span_id}:{name}:{attributes}")).await;
            Ok(())
        }
    }

    #[tokio::test]
    async fn agents_can_record_traces() {
        let llm = Arc::new(MockLlm::new(vec![text_reply("ok")]));
        let tracer = Arc::new(RecordingTracer::default());
        let agent = Agent::with_llm(llm.clone(), test_config("helper")).with_tracer(tracer.clone());

        agent.invoke("Hi").await.expect("agent");

        let log = tracer.events.lock().await;
        let saw = |prefix: &str| log.iter().any(|entry| entry.starts_with(prefix));
        assert!(saw("start:agent.invoke"));
        assert!(saw("end:agent.invoke"));
    }
}