1use std::collections::HashMap;
2
3use futures::{stream, StreamExt, TryStreamExt};
4
5use crate::{
6 completion::{
7 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
8 Message, Prompt, PromptError,
9 },
10 streaming::{
11 StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
12 StreamingResult,
13 },
14 tool::ToolSet,
15 vector_store::VectorStoreError,
16};
17
18use super::prompt_request::PromptRequest;
19
/// An agent: a completion model together with the preamble, context documents, tools and
/// request settings that its `Completion` impl applies to every request it builds.
pub struct Agent<M: CompletionModel> {
    /// Model used to create each request (via `completion_request`).
    pub model: M,
    /// Preamble set on every request built by this agent.
    pub preamble: String,
    /// Documents attached to every request, independent of the prompt.
    pub static_context: Vec<Document>,
    /// Names of tools in `tools` whose definitions are attached to every request.
    /// Names that are not present in `tools` are skipped with a warning.
    pub static_tools: Vec<String>,
    /// Temperature forwarded to the request (via `temperature_opt`) when set.
    pub temperature: Option<f64>,
    /// Max-tokens value forwarded to the request (via `max_tokens_opt`) when set.
    pub max_tokens: Option<u64>,
    /// Additional JSON parameters forwarded to the request (via `additional_params_opt`)
    /// when set.
    pub additional_params: Option<serde_json::Value>,
    /// Indices queried with the RAG text when one is available. Each entry is
    /// `(number of results to fetch, index)`; hits are attached as extra documents.
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Indices queried with the RAG text for tool ids when one is available. Each entry is
    /// `(number of ids to fetch, index)`; ids found in `tools` have their definitions attached.
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Tool implementations, looked up by name for both static and dynamic tools.
    pub tools: ToolSet,
}
62
63impl<M: CompletionModel> Completion<M> for Agent<M> {
64 async fn completion(
65 &self,
66 prompt: impl Into<Message> + Send,
67 chat_history: Vec<Message>,
68 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
69 let prompt = prompt.into();
70
71 let rag_text = prompt.rag_text();
73 let rag_text = rag_text.or_else(|| {
74 chat_history
75 .iter()
76 .rev()
77 .find_map(|message| message.rag_text())
78 });
79
80 let completion_request = self
81 .model
82 .completion_request(prompt)
83 .preamble(self.preamble.clone())
84 .messages(chat_history)
85 .temperature_opt(self.temperature)
86 .max_tokens_opt(self.max_tokens)
87 .additional_params_opt(self.additional_params.clone())
88 .documents(self.static_context.clone());
89
90 let agent = match &rag_text {
92 Some(text) => {
93 let dynamic_context = stream::iter(self.dynamic_context.iter())
94 .then(|(num_sample, index)| async {
95 Ok::<_, VectorStoreError>(
96 index
97 .top_n(text, *num_sample)
98 .await?
99 .into_iter()
100 .map(|(_, id, doc)| {
101 let text = serde_json::to_string_pretty(&doc)
103 .unwrap_or_else(|_| doc.to_string());
104
105 Document {
106 id,
107 text,
108 additional_props: HashMap::new(),
109 }
110 })
111 .collect::<Vec<_>>(),
112 )
113 })
114 .try_fold(vec![], |mut acc, docs| async {
115 acc.extend(docs);
116 Ok(acc)
117 })
118 .await
119 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
120
121 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
122 .then(|(num_sample, index)| async {
123 Ok::<_, VectorStoreError>(
124 index
125 .top_n_ids(text, *num_sample)
126 .await?
127 .into_iter()
128 .map(|(_, id)| id)
129 .collect::<Vec<_>>(),
130 )
131 })
132 .try_fold(vec![], |mut acc, docs| async {
133 for doc in docs {
134 if let Some(tool) = self.tools.get(&doc) {
135 acc.push(tool.definition(text.into()).await)
136 } else {
137 tracing::warn!("Tool implementation not found in toolset: {}", doc);
138 }
139 }
140 Ok(acc)
141 })
142 .await
143 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
144
145 let static_tools = stream::iter(self.static_tools.iter())
146 .filter_map(|toolname| async move {
147 if let Some(tool) = self.tools.get(toolname) {
148 Some(tool.definition(text.into()).await)
149 } else {
150 tracing::warn!(
151 "Tool implementation not found in toolset: {}",
152 toolname
153 );
154 None
155 }
156 })
157 .collect::<Vec<_>>()
158 .await;
159
160 completion_request
161 .documents(dynamic_context)
162 .tools([static_tools.clone(), dynamic_tools].concat())
163 }
164 None => {
165 let static_tools = stream::iter(self.static_tools.iter())
166 .filter_map(|toolname| async move {
167 if let Some(tool) = self.tools.get(toolname) {
168 Some(tool.definition("".into()).await)
170 } else {
171 tracing::warn!(
172 "Tool implementation not found in toolset: {}",
173 toolname
174 );
175 None
176 }
177 })
178 .collect::<Vec<_>>()
179 .await;
180
181 completion_request.tools(static_tools)
182 }
183 };
184
185 Ok(agent)
186 }
187}
188
189#[allow(refining_impl_trait)]
197impl<M: CompletionModel> Prompt for Agent<M> {
198 fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
199 PromptRequest::new(self, prompt)
200 }
201}
202
203#[allow(refining_impl_trait)]
204impl<M: CompletionModel> Prompt for &Agent<M> {
205 fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
206 PromptRequest::new(*self, prompt)
207 }
208}
209
210#[allow(refining_impl_trait)]
211impl<M: CompletionModel> Chat for Agent<M> {
212 async fn chat(
213 &self,
214 prompt: impl Into<Message> + Send,
215 chat_history: Vec<Message>,
216 ) -> Result<String, PromptError> {
217 let mut cloned_history = chat_history.clone();
218 PromptRequest::new(self, prompt)
219 .with_history(&mut cloned_history)
220 .await
221 }
222}
223
224impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
225 async fn stream_completion(
226 &self,
227 prompt: impl Into<Message> + Send,
228 chat_history: Vec<Message>,
229 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
230 self.completion(prompt, chat_history).await
233 }
234}
235
236impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
237 async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
238 self.stream_chat(prompt, vec![]).await
239 }
240}
241
242impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
243 async fn stream_chat(
244 &self,
245 prompt: &str,
246 chat_history: Vec<Message>,
247 ) -> Result<StreamingResult, CompletionError> {
248 self.stream_completion(prompt, chat_history)
249 .await?
250 .stream()
251 .await
252 }
253}