1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 agent::prompt_request::streaming::StreamingPromptRequest,
4 completion::{
5 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6 GetTokenUsage, Message, Prompt, PromptError,
7 },
8 message::ToolChoice,
9 streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10 tool::server::ToolServerHandle,
11 vector_store::{VectorStoreError, request::VectorSearchRequest},
12 wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock;
17
/// Fallback label returned by `Agent::name` when the agent's `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared, lock-guarded list of dynamic context sources.
///
/// Each entry pairs a `usize` with a boxed vector store index. When the agent
/// builds a completion request, the `usize` is passed as the number of samples
/// to request from the paired index.
pub type DynamicContextStore =
    Arc<RwLock<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>>;
22
/// An agent: a completion model bundled with the settings, context sources and
/// tool server handle used to build its completion requests.
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Optional agent name; `Agent::name` falls back to a placeholder when unset.
    pub name: Option<String>,
    /// Optional free-form description of the agent.
    pub description: Option<String>,
    /// Completion model used to create completion requests.
    pub model: Arc<M>,
    /// Optional preamble attached to the completion request when set.
    pub preamble: Option<String>,
    /// Documents attached to every completion request built by this agent.
    pub static_context: Vec<Document>,
    /// Optional temperature forwarded to the completion request.
    pub temperature: Option<f64>,
    /// Optional maximum token count forwarded to the completion request.
    pub max_tokens: Option<u64>,
    /// Optional extra parameters forwarded to the completion request.
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch the tool definitions attached to each request.
    pub tool_server_handle: ToolServerHandle,
    /// Vector store indexes queried with the request's RAG text, each paired
    /// with the number of samples to fetch from it.
    pub dynamic_context: DynamicContextStore,
    /// Optional tool choice setting.
    // NOTE(review): not read anywhere in this part of the file — presumably
    // applied where the request is actually sent; confirm against the caller.
    pub tool_choice: Option<ToolChoice>,
}
71
72impl<M> Agent<M>
73where
74 M: CompletionModel,
75{
76 pub(crate) fn name(&self) -> &str {
78 self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
79 }
80}
81
82impl<M> Completion<M> for Agent<M>
83where
84 M: CompletionModel,
85{
86 async fn completion(
87 &self,
88 prompt: impl Into<Message> + WasmCompatSend,
89 chat_history: Vec<Message>,
90 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
91 let prompt = prompt.into();
92
93 let rag_text = prompt.rag_text();
95 let rag_text = rag_text.or_else(|| {
96 chat_history
97 .iter()
98 .rev()
99 .find_map(|message| message.rag_text())
100 });
101
102 let completion_request = self
103 .model
104 .completion_request(prompt)
105 .messages(chat_history)
106 .temperature_opt(self.temperature)
107 .max_tokens_opt(self.max_tokens)
108 .additional_params_opt(self.additional_params.clone())
109 .documents(self.static_context.clone());
110 let completion_request = if let Some(preamble) = &self.preamble {
111 completion_request.preamble(preamble.to_owned())
112 } else {
113 completion_request
114 };
115
116 let agent = match &rag_text {
118 Some(text) => {
119 let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
120 .then(|(num_sample, index)| async {
121 let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
122 Ok::<_, VectorStoreError>(
123 index
124 .top_n(req)
125 .await?
126 .into_iter()
127 .map(|(_, id, doc)| {
128 let text = serde_json::to_string_pretty(&doc)
130 .unwrap_or_else(|_| doc.to_string());
131
132 Document {
133 id,
134 text,
135 additional_props: HashMap::new(),
136 }
137 })
138 .collect::<Vec<_>>(),
139 )
140 })
141 .try_fold(vec![], |mut acc, docs| async {
142 acc.extend(docs);
143 Ok(acc)
144 })
145 .await
146 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
147
148 let tooldefs = self
149 .tool_server_handle
150 .get_tool_defs(Some(text.to_string()))
151 .await
152 .map_err(|_| {
153 CompletionError::RequestError("Failed to get tool definitions".into())
154 })?;
155
156 completion_request
157 .documents(dynamic_context)
158 .tools(tooldefs)
159 }
160 None => {
161 let tooldefs = self
162 .tool_server_handle
163 .get_tool_defs(None)
164 .await
165 .map_err(|_| {
166 CompletionError::RequestError("Failed to get tool definitions".into())
167 })?;
168
169 completion_request.tools(tooldefs)
170 }
171 };
172
173 Ok(agent)
174 }
175}
176
177#[allow(refining_impl_trait)]
185impl<M> Prompt for Agent<M>
186where
187 M: CompletionModel,
188{
189 fn prompt(
190 &self,
191 prompt: impl Into<Message> + WasmCompatSend,
192 ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
193 PromptRequest::new(self, prompt)
194 }
195}
196
197#[allow(refining_impl_trait)]
198impl<M> Prompt for &Agent<M>
199where
200 M: CompletionModel,
201{
202 #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
203 fn prompt(
204 &self,
205 prompt: impl Into<Message> + WasmCompatSend,
206 ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
207 PromptRequest::new(*self, prompt)
208 }
209}
210
211#[allow(refining_impl_trait)]
212impl<M> Chat for Agent<M>
213where
214 M: CompletionModel,
215{
216 #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
217 async fn chat(
218 &self,
219 prompt: impl Into<Message> + WasmCompatSend,
220 mut chat_history: Vec<Message>,
221 ) -> Result<String, PromptError> {
222 PromptRequest::new(self, prompt)
223 .with_history(&mut chat_history)
224 .await
225 }
226}
227
228impl<M> StreamingCompletion<M> for Agent<M>
229where
230 M: CompletionModel,
231{
232 async fn stream_completion(
233 &self,
234 prompt: impl Into<Message> + WasmCompatSend,
235 chat_history: Vec<Message>,
236 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
237 self.completion(prompt, chat_history).await
240 }
241}
242
243impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
244where
245 M: CompletionModel + 'static,
246 M::StreamingResponse: GetTokenUsage,
247{
248 fn stream_prompt(
249 &self,
250 prompt: impl Into<Message> + WasmCompatSend,
251 ) -> StreamingPromptRequest<M, ()> {
252 let arc = Arc::new(self.clone());
253 StreamingPromptRequest::new(arc, prompt)
254 }
255}
256
257impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
258where
259 M: CompletionModel + 'static,
260 M::StreamingResponse: GetTokenUsage,
261{
262 fn stream_chat(
263 &self,
264 prompt: impl Into<Message> + WasmCompatSend,
265 chat_history: Vec<Message>,
266 ) -> StreamingPromptRequest<M, ()> {
267 let arc = Arc::new(self.clone());
268 StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
269 }
270}