llm_agent/
agent.rs

1use crate::{
2    run::{RunSession, RunSessionRequest},
3    types::AgentStream,
4    AgentError, AgentParams, AgentRequest, AgentResponse,
5};
6use futures::stream::StreamExt;
7use llm_sdk::LanguageModel;
8use std::sync::Arc;
9
10pub struct Agent<TCtx> {
11    /// A unique name for the agent.
12    /// The name can only contain letters and underscores.
13    pub name: String,
14    params: Arc<AgentParams<TCtx>>,
15}
16
17impl<TCtx> Agent<TCtx>
18where
19    TCtx: Send + Sync + 'static,
20{
21    #[must_use]
22    pub fn new(params: AgentParams<TCtx>) -> Self {
23        Self {
24            name: params.name.clone(),
25            params: Arc::new(params),
26        }
27    }
28    /// Create a one-time run of the agent and generate a response.
29    /// A session is created for the run and cleaned up afterwards.
30    pub async fn run(&self, request: AgentRequest<TCtx>) -> Result<AgentResponse, AgentError> {
31        let AgentRequest { input, context } = request;
32        let run_session = self.create_session(context).await?;
33        let result = run_session.run(RunSessionRequest { input }).await;
34        match result {
35            Ok(response) => {
36                if let Err(close_err) = run_session.close().await {
37                    Err(close_err)
38                } else {
39                    Ok(response)
40                }
41            }
42            Err(err) => {
43                let _ = run_session.close().await;
44                Err(err)
45            }
46        }
47    }
48
49    /// Create a one-time streaming run of the agent and generate a response.
50    /// A session is created for the run and cleaned up afterwards.
51    pub async fn run_stream(&self, request: AgentRequest<TCtx>) -> Result<AgentStream, AgentError> {
52        let AgentRequest { input, context } = request;
53        let run_session = Arc::new(self.create_session(context).await?);
54        let mut stream = match run_session.clone().run_stream(RunSessionRequest { input }) {
55            Ok(stream) => stream,
56            Err(err) => {
57                if let Ok(session) = Arc::try_unwrap(run_session) {
58                    let _ = session.close().await;
59                }
60                return Err(err);
61            }
62        };
63
64        let wrapped_stream = async_stream::stream! {
65            let run_session = run_session;
66            while let Some(item) = stream.next().await {
67                yield item;
68            }
69            if let Ok(session) = Arc::try_unwrap(run_session) {
70                if let Err(close_err) = session.close().await {
71                    yield Err(close_err);
72                }
73            }
74        };
75
76        Ok(AgentStream::from_stream(wrapped_stream))
77    }
78
79    /// Create a session for stateful multiple runs of the agent.
80    pub async fn create_session(&self, context: TCtx) -> Result<RunSession<TCtx>, AgentError> {
81        RunSession::new(self.params.clone(), context).await
82    }
83
84    pub fn builder(name: &str, model: Arc<dyn LanguageModel + Send + Sync>) -> AgentParams<TCtx> {
85        AgentParams::new(name, model)
86    }
87}