1use crate::{
2 run::{RunSession, RunSessionRequest},
3 types::AgentStream,
4 AgentError, AgentParams, AgentRequest, AgentResponse,
5};
6use futures::stream::StreamExt;
7use llm_sdk::LanguageModel;
8use std::sync::Arc;
9
10pub struct Agent<TCtx> {
11 pub name: String,
14 params: Arc<AgentParams<TCtx>>,
15}
16
17impl<TCtx> Agent<TCtx>
18where
19 TCtx: Send + Sync + 'static,
20{
21 #[must_use]
22 pub fn new(params: AgentParams<TCtx>) -> Self {
23 Self {
24 name: params.name.clone(),
25 params: Arc::new(params),
26 }
27 }
28 pub async fn run(&self, request: AgentRequest<TCtx>) -> Result<AgentResponse, AgentError> {
31 let AgentRequest { input, context } = request;
32 let run_session = self.create_session(context).await?;
33 let result = run_session.run(RunSessionRequest { input }).await;
34 match result {
35 Ok(response) => {
36 if let Err(close_err) = run_session.close().await {
37 Err(close_err)
38 } else {
39 Ok(response)
40 }
41 }
42 Err(err) => {
43 let _ = run_session.close().await;
44 Err(err)
45 }
46 }
47 }
48
49 pub async fn run_stream(&self, request: AgentRequest<TCtx>) -> Result<AgentStream, AgentError> {
52 let AgentRequest { input, context } = request;
53 let run_session = Arc::new(self.create_session(context).await?);
54 let mut stream = match run_session.clone().run_stream(RunSessionRequest { input }) {
55 Ok(stream) => stream,
56 Err(err) => {
57 if let Ok(session) = Arc::try_unwrap(run_session) {
58 let _ = session.close().await;
59 }
60 return Err(err);
61 }
62 };
63
64 let wrapped_stream = async_stream::stream! {
65 let run_session = run_session;
66 while let Some(item) = stream.next().await {
67 yield item;
68 }
69 if let Ok(session) = Arc::try_unwrap(run_session) {
70 if let Err(close_err) = session.close().await {
71 yield Err(close_err);
72 }
73 }
74 };
75
76 Ok(AgentStream::from_stream(wrapped_stream))
77 }
78
79 pub async fn create_session(&self, context: TCtx) -> Result<RunSession<TCtx>, AgentError> {
81 RunSession::new(self.params.clone(), context).await
82 }
83
84 pub fn builder(name: &str, model: Arc<dyn LanguageModel + Send + Sync>) -> AgentParams<TCtx> {
85 AgentParams::new(name, model)
86 }
87}