llm_agent/
run.rs

1use crate::{
2    instruction,
3    opentelemetry::{start_tool_span, trace_agent_run, trace_agent_stream, AgentSpanMethod},
4    toolkit::ToolkitSession,
5    types::{AgentItemTool, AgentStream, AgentStreamEvent},
6    AgentError, AgentItem, AgentParams, AgentResponse, AgentStreamItemEvent, AgentTool,
7};
8use async_stream::try_stream;
9use futures::{
10    future::{join_all, try_join_all},
11    lock::Mutex,
12    stream::StreamExt,
13};
14use llm_sdk::{
15    boxed_stream::BoxedStream, LanguageModelInput, Message, ModelResponse, Part, StreamAccumulator,
16    ToolCallPart, ToolResultPart,
17};
18use std::{collections::HashSet, sync::Arc};
19
20/// Manages the run session for an agent.
21/// It initializes all necessary components for the agent to run
22/// and handles the execution of the agent's tasks.
23/// Once finished, the session cleans up any resources used during the run.
24/// The session can be reused in multiple runs. `RunSession` binds to a specific
25///
26/// context value that is used to resolve instructions and invoke tools, while
27/// input items remain per run and are supplied to each invocation.
28pub struct RunSession<TCtx> {
29    /// Agent configuration used during the run session.
30    params: Arc<AgentParams<TCtx>>,
31    /// The bound context value passed to instruction resolvers and tools.
32    context: Arc<TCtx>,
33    /// System prompt generated from the static instruction params.
34    system_prompt: Option<String>,
35    /// Toolkit sessions created for this run session.
36    toolkit_sessions: Arc<Vec<Box<dyn ToolkitSession<TCtx> + Send + Sync>>>,
37}
38
39impl<TCtx> RunSession<TCtx>
40where
41    TCtx: Send + Sync + 'static,
42{
43    /// Creates a new run session and initializes dependencies
44    #[allow(clippy::unused_async)]
45    #[allow(clippy::too_many_arguments)]
46    pub async fn new(params: Arc<AgentParams<TCtx>>, context: TCtx) -> Result<Self, AgentError> {
47        let system_prompt = if params.instructions.is_empty() {
48            None
49        } else {
50            Some(
51                instruction::get_prompt(&params.instructions, &context)
52                    .await
53                    .map_err(AgentError::Init)?,
54            )
55        };
56
57        let toolkit_sessions = Self::initialize(&params, &context).await?;
58
59        Ok(Self {
60            params,
61            context: Arc::new(context),
62            system_prompt,
63            toolkit_sessions: Arc::new(toolkit_sessions),
64        })
65    }
66
67    /// `process()` flow:
68    /// 1. Peek latest run item to locate assistant content.
69    ///
70    ///    1a. Tail is user message -> emit `Next`. Go to 3.
71    ///
72    ///    1b. Tail is tool/tool message -> gather processed ids, backtrack to
73    ///    assistant/model content. Go to 2.
74    ///
75    ///    1c. Tail is assistant/model -> use its content. Go to 2.
76    ///
77    /// 2. Scan assistant content for tool calls.
78    ///
79    ///    2a. Tool calls remaining -> execute unprocessed tools, emit each
80    ///    `Item`, then emit `Next`. Go to 3.
81    ///
82    ///    2b. No tool calls -> emit `Response`. Go to 4.
83    ///
84    /// 3. Outer loop: bump turn, refresh params, request model response, append
85    ///    it, then re-enter step 1.
86    ///
87    /// 4. Return final response to caller.
88    #[allow(clippy::too_many_lines)]
89    fn process<'a>(
90        &'a self,
91        run_state: &'a RunState,
92        tools: Vec<Arc<dyn AgentTool<TCtx>>>,
93    ) -> BoxedStream<'a, Result<ProcessEvents, AgentError>> {
94        let context_val = self.context.clone();
95        let stream = try_stream! {
96            let items = run_state.items().await;
97            // Examining the last items in the state determines the next step.
98            let last_item = items.last().cloned().ok_or_else(|| {
99                AgentError::Invariant("No items in the run state.".to_string())
100            })?;
101
102            let mut content: Option<Vec<Part>> = None;
103            let mut processed_tool_call_ids: HashSet<String> = HashSet::new();
104
105            match last_item {
106                AgentItem::Model(model_response) => {
107                    // ========== Case: Assistant Message [from AgentItemModelResponse] ==========
108                    // Last item is a model response, process it
109                    content = Some(model_response.content);
110                }
111                AgentItem::Message(message) => match message {
112                    Message::Assistant(assistant_message) => {
113                        // ========== Case: Assistant Message [from AgentItemMessage] ==========
114                        // Last item is an assistant message, process it
115                        content = Some(assistant_message.content);
116                    }
117                    Message::User(_) => {
118                        // ========== Case: User Message ==========
119                        // last item is a user message, so we need to generate a model response
120                        yield ProcessEvents::Next;
121                        return;
122                    }
123                    Message::Tool(tool_message) => {
124                        // ========== Case: Tool Results (from AgentItemMessage) ==========
125                        // Track the tool call ids that have been processed to avoid duplicate execution
126                        for part in tool_message.content {
127                            if let Part::ToolResult(result) = part {
128                                processed_tool_call_ids.insert(result.tool_call_id);
129                            }
130                        }
131
132                        // We are in the middle of processing tool results, the 2nd last item should be a model response
133                        let previous_item = items
134                            .len()
135                            .checked_sub(2)
136                            .and_then(|idx| items.get(idx))
137                            .cloned()
138                            .ok_or_else(|| {
139                                AgentError::Invariant(
140                                    "No preceding assistant content found before tool results.".to_string(),
141                                )
142                            })?;
143
144                        let resolved = match previous_item {
145                            AgentItem::Model(model_response) => model_response.content,
146                            AgentItem::Message(prev_message) => match prev_message {
147                                Message::Assistant(assistant_message) => assistant_message.content,
148                                _ => {
149                                    Err(AgentError::Invariant(
150                                        "Expected a model item or assistant message before tool results.".to_string(),
151                                    ))?
152                                }
153                            },
154                            AgentItem::Tool(_) => {
155                                Err(AgentError::Invariant(
156                                    "Expected a model item or assistant message before tool results.".to_string(),
157                                ))?
158                            }
159                        };
160                        content = Some(resolved);
161                    }
162                },
163                AgentItem::Tool(_) => {
164                    // ========== Case: Tool Results (from AgentItemTool) ==========
165                    // Each tool result is an individual item in this representation, so there could be other
166                    // AgentItemTool before this one. We loop backwards to find the first non-tool item while also
167                    // tracking the called tool ids to avoid duplicate execution
168                    for item in items.into_iter().rev() {
169                        match item {
170                            AgentItem::Tool(tool_item) => {
171                                processed_tool_call_ids.insert(tool_item.tool_call_id);
172                                // Continue searching for the originating model/assistant item
173                            }
174                            AgentItem::Model(model_response) => {
175                                // Found the originating model response
176                                content = Some(model_response.content);
177                                break;
178                            }
179                            AgentItem::Message(message) => match message {
180                                Message::Tool(tool_message) => {
181                                    // Collect all tool call ids in the tool message
182                                    for part in tool_message.content {
183                                        if let Part::ToolResult(result) = part {
184                                            processed_tool_call_ids.insert(result.tool_call_id);
185                                        }
186                                    }
187                                    // Continue searching for the originating model/assistant item
188                                }
189                                Message::Assistant(assistant_message) => {
190                                    // Found the originating model response
191                                    content = Some(assistant_message.content);
192                                    break;
193                                }
194                                Message::User(_) => {
195                                    Err(AgentError::Invariant(
196                                        "Expected a model item or assistant message before tool results.".to_string(),
197                                    ))?;
198                                }
199                            },
200                        }
201                    }
202                }
203            }
204
205            let content = content
206                .filter(|v| !v.is_empty())
207                .ok_or_else(|| AgentError::Invariant(
208                    "No assistant content found to process.".to_string(),
209                ))?;
210
211            let tool_call_parts: Vec<ToolCallPart> = content
212                .iter()
213                .filter_map(|part| {
214                    if let Part::ToolCall(tool_call) = part {
215                        Some(tool_call.clone())
216                    } else {
217                        None
218                    }
219                })
220                .collect();
221
222
223            // If no tool calls were found, return the model response as is
224            if tool_call_parts.is_empty() {
225                yield ProcessEvents::Response(content);
226                return;
227            }
228
229            for tool_call_part in tool_call_parts {
230                if processed_tool_call_ids.contains(&tool_call_part.tool_call_id)
231                {
232                    // Tool call has already been processed
233                    continue;
234                }
235
236                let ToolCallPart {
237                    tool_call_id,
238                    tool_name,
239                    args,
240                    ..
241                } = tool_call_part;
242
243                let agent_tool = tools
244                    .iter()
245                    .find(|tool| tool.name() == tool_name)
246                    .ok_or_else(|| {
247                        AgentError::Invariant(format!("Tool {tool_name} not found for tool call"))
248                    })?;
249
250                let tool_name_value = agent_tool.name();
251                let tool_description = agent_tool.description();
252                let tool_res = start_tool_span(
253                    &tool_call_id,
254                    &tool_name_value,
255                    &tool_description,
256                    agent_tool.execute(args.clone(), &context_val, run_state),
257                )
258                .await
259                .map_err(AgentError::ToolExecution)?;
260
261                let item = AgentItemTool {
262                    tool_call_id,
263                    tool_name,
264                    input: args,
265                    output: tool_res.content,
266                    is_error: tool_res.is_error,
267                };
268
269                yield ProcessEvents::Item(AgentItem::Tool(item));
270            }
271
272            yield ProcessEvents::Next;
273        };
274
275        BoxedStream::from_stream(stream)
276    }
277
278    /// Run a non-streaming execution of the agent.
279    pub async fn run(&self, request: RunSessionRequest) -> Result<AgentResponse, AgentError> {
280        let RunSessionRequest { input } = request;
281
282        trace_agent_run(&self.params.name, AgentSpanMethod::Run, async move {
283            let state = RunState::new(input, self.params.max_turns);
284            let mut tools = self.get_tools();
285
286            loop {
287                let mut process_stream = self.process(&state, tools);
288
289                while let Some(event) = process_stream.next().await {
290                    let event = event?;
291                    match event {
292                        ProcessEvents::Item(item) => {
293                            state.append_item(item).await;
294                        }
295                        ProcessEvents::Response(final_content) => {
296                            return Ok(state.create_response(final_content).await);
297                        }
298                        ProcessEvents::Next => {
299                            state.turn().await?;
300                            break;
301                        }
302                    }
303                }
304
305                let (input, next_tools) = self.get_turn_params(&state).await?;
306                tools = next_tools;
307
308                let model_response = self.params.model.generate(input).await?;
309                state.append_model_response(model_response).await;
310            }
311        })
312        .await
313    }
314
315    /// Run a streaming execution of the agent.
316    pub fn run_stream(&self, request: RunSessionRequest) -> Result<AgentStream, AgentError> {
317        let RunSessionRequest { input } = request;
318        let state = Arc::new(RunState::new(input, self.params.max_turns));
319
320        let session = Arc::new(Self {
321            params: self.params.clone(),
322            context: self.context.clone(),
323            system_prompt: self.system_prompt.clone(),
324            toolkit_sessions: self.toolkit_sessions.clone(),
325        });
326
327        let stream = async_stream::try_stream! {
328            let mut tools = session.get_tools();
329
330            loop {
331                let mut process_stream = session.process(&state, tools);
332
333                while let Some(event) = process_stream.next().await {
334                    let event = event?;
335
336                    match event {
337                        ProcessEvents::Item(item) => {
338                            let index = state.append_item(item.clone()).await;
339                            yield AgentStreamEvent::Item(AgentStreamItemEvent { index, item });
340                        }
341                        ProcessEvents::Response(final_content) => {
342                            let response = state.create_response(final_content).await;
343                            yield AgentStreamEvent::Response(response);
344                            return;
345                        }
346                        ProcessEvents::Next => {
347                            state.turn().await?;
348                            break;
349                        }
350                    }
351                }
352
353                let (input, next_tools) = session.get_turn_params(&state).await?;
354                tools = next_tools;
355
356                let mut model_stream = session.params.model.stream(input).await?;
357
358                let mut accumulator = StreamAccumulator::new();
359
360                while let Some(partial) = model_stream.next().await {
361                    let partial = partial?;
362
363                    accumulator.add_partial(partial.clone()).map_err(|e| {
364                        AgentError::Invariant(format!("Failed to accumulate stream: {e}"))
365                    })?;
366
367                    yield AgentStreamEvent::Partial(partial);
368                }
369
370                let model_response = accumulator.compute_response()?;
371
372                let (item, index) = state.append_model_response(model_response).await;
373                yield AgentStreamEvent::Item(AgentStreamItemEvent { index, item });
374            }
375        };
376
377        Ok(trace_agent_stream(&self.params.name, stream))
378    }
379
380    pub async fn close(self) -> Result<(), AgentError> {
381        if let Ok(toolkit_sessions) = Arc::try_unwrap(self.toolkit_sessions) {
382            let _ = join_all(
383                toolkit_sessions
384                    .into_iter()
385                    .map(super::toolkit::ToolkitSession::close),
386            )
387            .await;
388        }
389
390        Ok(())
391    }
392
393    async fn initialize(
394        params: &AgentParams<TCtx>,
395        context: &TCtx,
396    ) -> Result<Vec<Box<dyn ToolkitSession<TCtx> + Send + Sync>>, AgentError> {
397        let toolkit_sessions = if params.toolkits.is_empty() {
398            Vec::new()
399        } else {
400            let futures = params.toolkits.iter().map(|toolkit| async move {
401                toolkit
402                    .create_session(context)
403                    .await
404                    .map_err(AgentError::Init)
405            });
406
407            try_join_all(futures).await?
408        };
409        Ok(toolkit_sessions)
410    }
411
412    async fn get_turn_params(
413        &self,
414        state: &RunState,
415    ) -> Result<(LanguageModelInput, Vec<Arc<dyn AgentTool<TCtx>>>), AgentError> {
416        let mut system_prompts = Vec::new();
417        if let Some(prompt) = &self.system_prompt {
418            if !prompt.is_empty() {
419                system_prompts.push(prompt.clone());
420            }
421        }
422
423        for session in self.toolkit_sessions.iter() {
424            if let Some(prompt) = session.system_prompt() {
425                if !prompt.is_empty() {
426                    system_prompts.push(prompt);
427                }
428            }
429        }
430
431        let tools = self.get_tools();
432
433        let mut input = LanguageModelInput {
434            messages: state.get_turn_messages().await,
435            response_format: Some(self.params.response_format.clone()),
436            temperature: self.params.temperature,
437            top_p: self.params.top_p,
438            top_k: self.params.top_k,
439            presence_penalty: self.params.presence_penalty,
440            frequency_penalty: self.params.frequency_penalty,
441            modalities: self.params.modalities.clone(),
442            reasoning: self.params.reasoning.clone(),
443            audio: self.params.audio.clone(),
444            ..Default::default()
445        };
446
447        if !system_prompts.is_empty() {
448            input.system_prompt = Some(system_prompts.join("\n"));
449        }
450
451        if !tools.is_empty() {
452            let sdk_tools = tools.iter().map(|tool| tool.as_ref().into()).collect();
453            input.tools = Some(sdk_tools);
454        }
455
456        Ok((input, tools))
457    }
458
459    fn get_tools(&self) -> Vec<Arc<dyn AgentTool<TCtx>>> {
460        let mut tools: Vec<Arc<dyn AgentTool<TCtx>>> = self.params.tools.clone();
461        for session in self.toolkit_sessions.iter() {
462            let toolkit_tools = session.tools();
463            tools.extend(toolkit_tools);
464        }
465        tools
466    }
467}
468/// `RunSessionRequest` contains the input items used for a run.
469pub struct RunSessionRequest {
470    /// Input holds the items for this run, such as LLM messages.
471    pub input: Vec<AgentItem>,
472}
473
474enum ProcessEvents {
475    // Emit when a new item is generated
476    Item(AgentItem),
477    //  Emit when the final response is ready
478    Response(Vec<Part>),
479    // Emit when the loop should continue to the next iteration
480    Next,
481}
482
483pub struct RunState {
484    max_turns: usize,
485    input: Vec<AgentItem>,
486
487    /// The current turn number in the run.
488    pub current_turn: Arc<Mutex<usize>>,
489    /// All items generated during the run, such as new tool and model items
490    output: Arc<Mutex<Vec<AgentItem>>>,
491}
492
493impl RunState {
494    #[must_use]
495    fn new(input: Vec<AgentItem>, max_turns: usize) -> Self {
496        Self {
497            max_turns,
498            input,
499            current_turn: Arc::new(Mutex::new(0)),
500            output: Arc::new(Mutex::new(vec![])),
501        }
502    }
503
504    /// Mark a new turn in the conversation and throw an error if max turns
505    /// exceeded.
506    async fn turn(&self) -> Result<(), AgentError> {
507        let mut current_turn = self.current_turn.lock().await;
508        *current_turn += 1;
509        if *current_turn > self.max_turns {
510            return Err(AgentError::MaxTurnsExceeded(self.max_turns));
511        }
512        Ok(())
513    }
514
515    /// Add `AgentItems` to the run state and return the index of the added
516    /// item.
517    async fn append_item(&self, item: AgentItem) -> usize {
518        let mut output: futures::lock::MutexGuard<'_, Vec<AgentItem>> = self.output.lock().await;
519        output.push(item);
520        output.len() - 1
521    }
522
523    /// Return all items in the run, both input and output.
524    pub async fn items(&self) -> Vec<AgentItem> {
525        let output = self.output.lock().await;
526        self.input
527            .iter()
528            .cloned()
529            .chain(output.iter().cloned())
530            .collect()
531    }
532
533    /// Append a model response to the run state and return the created item and
534    /// its index.
535    async fn append_model_response(&self, response: ModelResponse) -> (AgentItem, usize) {
536        let mut output = self.output.lock().await;
537        let item = AgentItem::Model(response);
538        output.push(item.clone());
539        (item, output.len() - 1)
540    }
541
542    /// Get LLM messages to use in the `LanguageModelInput` for the turn
543    #[must_use]
544    async fn get_turn_messages(&self) -> Vec<Message> {
545        let output = self.output.lock().await;
546        let mut messages: Vec<Message> = Vec::new();
547        let iter = self.input.iter().cloned().chain(output.iter().cloned());
548
549        for item in iter {
550            match item {
551                AgentItem::Message(msg) => messages.push(msg),
552                AgentItem::Model(model_response) => {
553                    messages.push(Message::assistant(model_response.content));
554                }
555                AgentItem::Tool(tool) => {
556                    let tool_part: Part =
557                        ToolResultPart::new(tool.tool_call_id, tool.tool_name, tool.output)
558                            .with_is_error(tool.is_error)
559                            .into();
560
561                    match messages.last_mut() {
562                        Some(Message::Tool(last_tool_message)) => {
563                            last_tool_message.content.push(tool_part);
564                        }
565                        _ => {
566                            messages.push(Message::tool(vec![tool_part]));
567                        }
568                    }
569                }
570            }
571        }
572
573        messages
574    }
575
576    #[must_use]
577    async fn create_response(&self, final_content: Vec<Part>) -> AgentResponse {
578        let output = self.output.lock().await;
579        AgentResponse {
580            content: final_content,
581            output: output.clone(),
582        }
583    }
584}