1use crate::{
2 instruction,
3 opentelemetry::{start_tool_span, trace_agent_run, trace_agent_stream, AgentSpanMethod},
4 toolkit::ToolkitSession,
5 types::{AgentItemTool, AgentStream, AgentStreamEvent},
6 AgentError, AgentItem, AgentParams, AgentResponse, AgentStreamItemEvent, AgentTool,
7};
8use async_stream::try_stream;
9use futures::{
10 future::{join_all, try_join_all},
11 lock::Mutex,
12 stream::StreamExt,
13};
14use llm_sdk::{
15 boxed_stream::BoxedStream, LanguageModelInput, Message, ModelResponse, Part, StreamAccumulator,
16 ToolCallPart, ToolResultPart,
17};
18use std::{collections::HashSet, sync::Arc};
19
20pub struct RunSession<TCtx> {
29 params: Arc<AgentParams<TCtx>>,
31 context: Arc<TCtx>,
33 system_prompt: Option<String>,
35 toolkit_sessions: Arc<Vec<Box<dyn ToolkitSession<TCtx> + Send + Sync>>>,
37}
38
39impl<TCtx> RunSession<TCtx>
40where
41 TCtx: Send + Sync + 'static,
42{
43 #[allow(clippy::unused_async)]
45 #[allow(clippy::too_many_arguments)]
46 pub async fn new(params: Arc<AgentParams<TCtx>>, context: TCtx) -> Result<Self, AgentError> {
47 let system_prompt = if params.instructions.is_empty() {
48 None
49 } else {
50 Some(
51 instruction::get_prompt(¶ms.instructions, &context)
52 .await
53 .map_err(AgentError::Init)?,
54 )
55 };
56
57 let toolkit_sessions = Self::initialize(¶ms, &context).await?;
58
59 Ok(Self {
60 params,
61 context: Arc::new(context),
62 system_prompt,
63 toolkit_sessions: Arc::new(toolkit_sessions),
64 })
65 }
66
67 #[allow(clippy::too_many_lines)]
89 fn process<'a>(
90 &'a self,
91 run_state: &'a RunState,
92 tools: Vec<Arc<dyn AgentTool<TCtx>>>,
93 ) -> BoxedStream<'a, Result<ProcessEvents, AgentError>> {
94 let context_val = self.context.clone();
95 let stream = try_stream! {
96 let items = run_state.items().await;
97 let last_item = items.last().cloned().ok_or_else(|| {
99 AgentError::Invariant("No items in the run state.".to_string())
100 })?;
101
102 let mut content: Option<Vec<Part>> = None;
103 let mut processed_tool_call_ids: HashSet<String> = HashSet::new();
104
105 match last_item {
106 AgentItem::Model(model_response) => {
107 content = Some(model_response.content);
110 }
111 AgentItem::Message(message) => match message {
112 Message::Assistant(assistant_message) => {
113 content = Some(assistant_message.content);
116 }
117 Message::User(_) => {
118 yield ProcessEvents::Next;
121 return;
122 }
123 Message::Tool(tool_message) => {
124 for part in tool_message.content {
127 if let Part::ToolResult(result) = part {
128 processed_tool_call_ids.insert(result.tool_call_id);
129 }
130 }
131
132 let previous_item = items
134 .len()
135 .checked_sub(2)
136 .and_then(|idx| items.get(idx))
137 .cloned()
138 .ok_or_else(|| {
139 AgentError::Invariant(
140 "No preceding assistant content found before tool results.".to_string(),
141 )
142 })?;
143
144 let resolved = match previous_item {
145 AgentItem::Model(model_response) => model_response.content,
146 AgentItem::Message(prev_message) => match prev_message {
147 Message::Assistant(assistant_message) => assistant_message.content,
148 _ => {
149 Err(AgentError::Invariant(
150 "Expected a model item or assistant message before tool results.".to_string(),
151 ))?
152 }
153 },
154 AgentItem::Tool(_) => {
155 Err(AgentError::Invariant(
156 "Expected a model item or assistant message before tool results.".to_string(),
157 ))?
158 }
159 };
160 content = Some(resolved);
161 }
162 },
163 AgentItem::Tool(_) => {
164 for item in items.into_iter().rev() {
169 match item {
170 AgentItem::Tool(tool_item) => {
171 processed_tool_call_ids.insert(tool_item.tool_call_id);
172 }
174 AgentItem::Model(model_response) => {
175 content = Some(model_response.content);
177 break;
178 }
179 AgentItem::Message(message) => match message {
180 Message::Tool(tool_message) => {
181 for part in tool_message.content {
183 if let Part::ToolResult(result) = part {
184 processed_tool_call_ids.insert(result.tool_call_id);
185 }
186 }
187 }
189 Message::Assistant(assistant_message) => {
190 content = Some(assistant_message.content);
192 break;
193 }
194 Message::User(_) => {
195 Err(AgentError::Invariant(
196 "Expected a model item or assistant message before tool results.".to_string(),
197 ))?;
198 }
199 },
200 }
201 }
202 }
203 }
204
205 let content = content
206 .filter(|v| !v.is_empty())
207 .ok_or_else(|| AgentError::Invariant(
208 "No assistant content found to process.".to_string(),
209 ))?;
210
211 let tool_call_parts: Vec<ToolCallPart> = content
212 .iter()
213 .filter_map(|part| {
214 if let Part::ToolCall(tool_call) = part {
215 Some(tool_call.clone())
216 } else {
217 None
218 }
219 })
220 .collect();
221
222
223 if tool_call_parts.is_empty() {
225 yield ProcessEvents::Response(content);
226 return;
227 }
228
229 for tool_call_part in tool_call_parts {
230 if processed_tool_call_ids.contains(&tool_call_part.tool_call_id)
231 {
232 continue;
234 }
235
236 let ToolCallPart {
237 tool_call_id,
238 tool_name,
239 args,
240 ..
241 } = tool_call_part;
242
243 let agent_tool = tools
244 .iter()
245 .find(|tool| tool.name() == tool_name)
246 .ok_or_else(|| {
247 AgentError::Invariant(format!("Tool {tool_name} not found for tool call"))
248 })?;
249
250 let tool_name_value = agent_tool.name();
251 let tool_description = agent_tool.description();
252 let tool_res = start_tool_span(
253 &tool_call_id,
254 &tool_name_value,
255 &tool_description,
256 agent_tool.execute(args.clone(), &context_val, run_state),
257 )
258 .await
259 .map_err(AgentError::ToolExecution)?;
260
261 let item = AgentItemTool {
262 tool_call_id,
263 tool_name,
264 input: args,
265 output: tool_res.content,
266 is_error: tool_res.is_error,
267 };
268
269 yield ProcessEvents::Item(AgentItem::Tool(item));
270 }
271
272 yield ProcessEvents::Next;
273 };
274
275 BoxedStream::from_stream(stream)
276 }
277
278 pub async fn run(&self, request: RunSessionRequest) -> Result<AgentResponse, AgentError> {
280 let RunSessionRequest { input } = request;
281
282 trace_agent_run(&self.params.name, AgentSpanMethod::Run, async move {
283 let state = RunState::new(input, self.params.max_turns);
284 let mut tools = self.get_tools();
285
286 loop {
287 let mut process_stream = self.process(&state, tools);
288
289 while let Some(event) = process_stream.next().await {
290 let event = event?;
291 match event {
292 ProcessEvents::Item(item) => {
293 state.append_item(item).await;
294 }
295 ProcessEvents::Response(final_content) => {
296 return Ok(state.create_response(final_content).await);
297 }
298 ProcessEvents::Next => {
299 state.turn().await?;
300 break;
301 }
302 }
303 }
304
305 let (input, next_tools) = self.get_turn_params(&state).await?;
306 tools = next_tools;
307
308 let model_response = self.params.model.generate(input).await?;
309 state.append_model_response(model_response).await;
310 }
311 })
312 .await
313 }
314
315 pub fn run_stream(&self, request: RunSessionRequest) -> Result<AgentStream, AgentError> {
317 let RunSessionRequest { input } = request;
318 let state = Arc::new(RunState::new(input, self.params.max_turns));
319
320 let session = Arc::new(Self {
321 params: self.params.clone(),
322 context: self.context.clone(),
323 system_prompt: self.system_prompt.clone(),
324 toolkit_sessions: self.toolkit_sessions.clone(),
325 });
326
327 let stream = async_stream::try_stream! {
328 let mut tools = session.get_tools();
329
330 loop {
331 let mut process_stream = session.process(&state, tools);
332
333 while let Some(event) = process_stream.next().await {
334 let event = event?;
335
336 match event {
337 ProcessEvents::Item(item) => {
338 let index = state.append_item(item.clone()).await;
339 yield AgentStreamEvent::Item(AgentStreamItemEvent { index, item });
340 }
341 ProcessEvents::Response(final_content) => {
342 let response = state.create_response(final_content).await;
343 yield AgentStreamEvent::Response(response);
344 return;
345 }
346 ProcessEvents::Next => {
347 state.turn().await?;
348 break;
349 }
350 }
351 }
352
353 let (input, next_tools) = session.get_turn_params(&state).await?;
354 tools = next_tools;
355
356 let mut model_stream = session.params.model.stream(input).await?;
357
358 let mut accumulator = StreamAccumulator::new();
359
360 while let Some(partial) = model_stream.next().await {
361 let partial = partial?;
362
363 accumulator.add_partial(partial.clone()).map_err(|e| {
364 AgentError::Invariant(format!("Failed to accumulate stream: {e}"))
365 })?;
366
367 yield AgentStreamEvent::Partial(partial);
368 }
369
370 let model_response = accumulator.compute_response()?;
371
372 let (item, index) = state.append_model_response(model_response).await;
373 yield AgentStreamEvent::Item(AgentStreamItemEvent { index, item });
374 }
375 };
376
377 Ok(trace_agent_stream(&self.params.name, stream))
378 }
379
380 pub async fn close(self) -> Result<(), AgentError> {
381 if let Ok(toolkit_sessions) = Arc::try_unwrap(self.toolkit_sessions) {
382 let _ = join_all(
383 toolkit_sessions
384 .into_iter()
385 .map(super::toolkit::ToolkitSession::close),
386 )
387 .await;
388 }
389
390 Ok(())
391 }
392
393 async fn initialize(
394 params: &AgentParams<TCtx>,
395 context: &TCtx,
396 ) -> Result<Vec<Box<dyn ToolkitSession<TCtx> + Send + Sync>>, AgentError> {
397 let toolkit_sessions = if params.toolkits.is_empty() {
398 Vec::new()
399 } else {
400 let futures = params.toolkits.iter().map(|toolkit| async move {
401 toolkit
402 .create_session(context)
403 .await
404 .map_err(AgentError::Init)
405 });
406
407 try_join_all(futures).await?
408 };
409 Ok(toolkit_sessions)
410 }
411
412 async fn get_turn_params(
413 &self,
414 state: &RunState,
415 ) -> Result<(LanguageModelInput, Vec<Arc<dyn AgentTool<TCtx>>>), AgentError> {
416 let mut system_prompts = Vec::new();
417 if let Some(prompt) = &self.system_prompt {
418 if !prompt.is_empty() {
419 system_prompts.push(prompt.clone());
420 }
421 }
422
423 for session in self.toolkit_sessions.iter() {
424 if let Some(prompt) = session.system_prompt() {
425 if !prompt.is_empty() {
426 system_prompts.push(prompt);
427 }
428 }
429 }
430
431 let tools = self.get_tools();
432
433 let mut input = LanguageModelInput {
434 messages: state.get_turn_messages().await,
435 response_format: Some(self.params.response_format.clone()),
436 temperature: self.params.temperature,
437 top_p: self.params.top_p,
438 top_k: self.params.top_k,
439 presence_penalty: self.params.presence_penalty,
440 frequency_penalty: self.params.frequency_penalty,
441 modalities: self.params.modalities.clone(),
442 reasoning: self.params.reasoning.clone(),
443 audio: self.params.audio.clone(),
444 ..Default::default()
445 };
446
447 if !system_prompts.is_empty() {
448 input.system_prompt = Some(system_prompts.join("\n"));
449 }
450
451 if !tools.is_empty() {
452 let sdk_tools = tools.iter().map(|tool| tool.as_ref().into()).collect();
453 input.tools = Some(sdk_tools);
454 }
455
456 Ok((input, tools))
457 }
458
459 fn get_tools(&self) -> Vec<Arc<dyn AgentTool<TCtx>>> {
460 let mut tools: Vec<Arc<dyn AgentTool<TCtx>>> = self.params.tools.clone();
461 for session in self.toolkit_sessions.iter() {
462 let toolkit_tools = session.tools();
463 tools.extend(toolkit_tools);
464 }
465 tools
466 }
467}
468pub struct RunSessionRequest {
470 pub input: Vec<AgentItem>,
472}
473
474enum ProcessEvents {
475 Item(AgentItem),
477 Response(Vec<Part>),
479 Next,
481}
482
483pub struct RunState {
484 max_turns: usize,
485 input: Vec<AgentItem>,
486
487 pub current_turn: Arc<Mutex<usize>>,
489 output: Arc<Mutex<Vec<AgentItem>>>,
491}
492
493impl RunState {
494 #[must_use]
495 fn new(input: Vec<AgentItem>, max_turns: usize) -> Self {
496 Self {
497 max_turns,
498 input,
499 current_turn: Arc::new(Mutex::new(0)),
500 output: Arc::new(Mutex::new(vec![])),
501 }
502 }
503
504 async fn turn(&self) -> Result<(), AgentError> {
507 let mut current_turn = self.current_turn.lock().await;
508 *current_turn += 1;
509 if *current_turn > self.max_turns {
510 return Err(AgentError::MaxTurnsExceeded(self.max_turns));
511 }
512 Ok(())
513 }
514
515 async fn append_item(&self, item: AgentItem) -> usize {
518 let mut output: futures::lock::MutexGuard<'_, Vec<AgentItem>> = self.output.lock().await;
519 output.push(item);
520 output.len() - 1
521 }
522
523 pub async fn items(&self) -> Vec<AgentItem> {
525 let output = self.output.lock().await;
526 self.input
527 .iter()
528 .cloned()
529 .chain(output.iter().cloned())
530 .collect()
531 }
532
533 async fn append_model_response(&self, response: ModelResponse) -> (AgentItem, usize) {
536 let mut output = self.output.lock().await;
537 let item = AgentItem::Model(response);
538 output.push(item.clone());
539 (item, output.len() - 1)
540 }
541
542 #[must_use]
544 async fn get_turn_messages(&self) -> Vec<Message> {
545 let output = self.output.lock().await;
546 let mut messages: Vec<Message> = Vec::new();
547 let iter = self.input.iter().cloned().chain(output.iter().cloned());
548
549 for item in iter {
550 match item {
551 AgentItem::Message(msg) => messages.push(msg),
552 AgentItem::Model(model_response) => {
553 messages.push(Message::assistant(model_response.content));
554 }
555 AgentItem::Tool(tool) => {
556 let tool_part: Part =
557 ToolResultPart::new(tool.tool_call_id, tool.tool_name, tool.output)
558 .with_is_error(tool.is_error)
559 .into();
560
561 match messages.last_mut() {
562 Some(Message::Tool(last_tool_message)) => {
563 last_tool_message.content.push(tool_part);
564 }
565 _ => {
566 messages.push(Message::tool(vec![tool_part]));
567 }
568 }
569 }
570 }
571 }
572
573 messages
574 }
575
576 #[must_use]
577 async fn create_response(&self, final_content: Vec<Part>) -> AgentResponse {
578 let output = self.output.lock().await;
579 AgentResponse {
580 content: final_content,
581 output: output.clone(),
582 }
583 }
584}