1use super::prompt_request::PromptRequest;
2use crate::{
3 completion::{
4 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
5 Message, Prompt, PromptError,
6 },
7 streaming::{StreamingChat, StreamingCompletion, StreamingCompletionResponse, StreamingPrompt},
8 tool::ToolSet,
9 vector_store::VectorStoreError,
10};
11use futures::{stream, StreamExt, TryStreamExt};
12use std::collections::HashMap;
13
14pub struct Agent<M: CompletionModel> {
35 pub model: M,
37 pub preamble: String,
39 pub static_context: Vec<Document>,
41 pub static_tools: Vec<String>,
43 pub temperature: Option<f64>,
45 pub max_tokens: Option<u64>,
47 pub additional_params: Option<serde_json::Value>,
49 pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
51 pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
53 pub tools: ToolSet,
55}
56
57impl<M: CompletionModel> Completion<M> for Agent<M> {
58 async fn completion(
59 &self,
60 prompt: impl Into<Message> + Send,
61 chat_history: Vec<Message>,
62 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
63 let prompt = prompt.into();
64
65 let rag_text = prompt.rag_text();
67 let rag_text = rag_text.or_else(|| {
68 chat_history
69 .iter()
70 .rev()
71 .find_map(|message| message.rag_text())
72 });
73
74 let completion_request = self
75 .model
76 .completion_request(prompt)
77 .preamble(self.preamble.clone())
78 .messages(chat_history)
79 .temperature_opt(self.temperature)
80 .max_tokens_opt(self.max_tokens)
81 .additional_params_opt(self.additional_params.clone())
82 .documents(self.static_context.clone());
83
84 let agent = match &rag_text {
86 Some(text) => {
87 let dynamic_context = stream::iter(self.dynamic_context.iter())
88 .then(|(num_sample, index)| async {
89 Ok::<_, VectorStoreError>(
90 index
91 .top_n(text, *num_sample)
92 .await?
93 .into_iter()
94 .map(|(_, id, doc)| {
95 let text = serde_json::to_string_pretty(&doc)
97 .unwrap_or_else(|_| doc.to_string());
98
99 Document {
100 id,
101 text,
102 additional_props: HashMap::new(),
103 }
104 })
105 .collect::<Vec<_>>(),
106 )
107 })
108 .try_fold(vec![], |mut acc, docs| async {
109 acc.extend(docs);
110 Ok(acc)
111 })
112 .await
113 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
114
115 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
116 .then(|(num_sample, index)| async {
117 Ok::<_, VectorStoreError>(
118 index
119 .top_n_ids(text, *num_sample)
120 .await?
121 .into_iter()
122 .map(|(_, id)| id)
123 .collect::<Vec<_>>(),
124 )
125 })
126 .try_fold(vec![], |mut acc, docs| async {
127 for doc in docs {
128 if let Some(tool) = self.tools.get(&doc) {
129 acc.push(tool.definition(text.into()).await)
130 } else {
131 tracing::warn!("Tool implementation not found in toolset: {}", doc);
132 }
133 }
134 Ok(acc)
135 })
136 .await
137 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
138
139 let static_tools = stream::iter(self.static_tools.iter())
140 .filter_map(|toolname| async move {
141 if let Some(tool) = self.tools.get(toolname) {
142 Some(tool.definition(text.into()).await)
143 } else {
144 tracing::warn!(
145 "Tool implementation not found in toolset: {}",
146 toolname
147 );
148 None
149 }
150 })
151 .collect::<Vec<_>>()
152 .await;
153
154 completion_request
155 .documents(dynamic_context)
156 .tools([static_tools.clone(), dynamic_tools].concat())
157 }
158 None => {
159 let static_tools = stream::iter(self.static_tools.iter())
160 .filter_map(|toolname| async move {
161 if let Some(tool) = self.tools.get(toolname) {
162 Some(tool.definition("".into()).await)
164 } else {
165 tracing::warn!(
166 "Tool implementation not found in toolset: {}",
167 toolname
168 );
169 None
170 }
171 })
172 .collect::<Vec<_>>()
173 .await;
174
175 completion_request.tools(static_tools)
176 }
177 };
178
179 Ok(agent)
180 }
181}
182
183#[allow(refining_impl_trait)]
191impl<M: CompletionModel> Prompt for Agent<M> {
192 fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
193 PromptRequest::new(self, prompt)
194 }
195}
196
197#[allow(refining_impl_trait)]
198impl<M: CompletionModel> Prompt for &Agent<M> {
199 fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
200 PromptRequest::new(*self, prompt)
201 }
202}
203
204#[allow(refining_impl_trait)]
205impl<M: CompletionModel> Chat for Agent<M> {
206 async fn chat(
207 &self,
208 prompt: impl Into<Message> + Send,
209 chat_history: Vec<Message>,
210 ) -> Result<String, PromptError> {
211 let mut cloned_history = chat_history.clone();
212 PromptRequest::new(self, prompt)
213 .with_history(&mut cloned_history)
214 .await
215 }
216}
217
218impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
219 async fn stream_completion(
220 &self,
221 prompt: impl Into<Message> + Send,
222 chat_history: Vec<Message>,
223 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
224 self.completion(prompt, chat_history).await
227 }
228}
229
230impl<M: CompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
231 async fn stream_prompt(
232 &self,
233 prompt: impl Into<Message> + Send,
234 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
235 self.stream_chat(prompt, vec![]).await
236 }
237}
238
239impl<M: CompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
240 async fn stream_chat(
241 &self,
242 prompt: impl Into<Message> + Send,
243 chat_history: Vec<Message>,
244 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
245 self.stream_completion(prompt, chat_history)
246 .await?
247 .stream()
248 .await
249 }
250}