1use super::prompt_request::{self, PromptRequest};
2use crate::{
3 completion::{
4 Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
5 Message, Prompt, PromptError,
6 },
7 streaming::{StreamingChat, StreamingCompletion, StreamingCompletionResponse, StreamingPrompt},
8 tool::ToolSet,
9 vector_store::VectorStoreError,
10};
11use futures::{StreamExt, TryStreamExt, stream};
12use std::collections::HashMap;
13
/// An LLM-powered agent: a completion model combined with a preamble,
/// context documents, tools, and optional generation parameters.
pub struct Agent<M: CompletionModel> {
    /// The underlying completion model used to serve requests.
    pub model: M,
    /// System prompt (preamble) attached to every completion request.
    pub preamble: String,
    /// Context documents attached to every completion request.
    pub static_context: Vec<Document>,
    /// Names of tools (resolved against `tools`) made available on every request.
    pub static_tools: Vec<String>,
    /// Sampling temperature forwarded to the model, when set.
    pub temperature: Option<f64>,
    /// Maximum number of completion tokens forwarded to the model, when set.
    pub max_tokens: Option<u64>,
    /// Provider-specific additional parameters forwarded to the model, when set.
    pub additional_params: Option<serde_json::Value>,
    /// Vector-store indices queried per request for extra context documents;
    /// the `usize` is the number of top results to fetch from that index.
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Vector-store indices queried per request for tool ids to enable;
    /// the `usize` is the number of top results to fetch from that index.
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Tool implementations backing both static and dynamic tool names.
    pub tools: ToolSet,
}
56
57impl<M: CompletionModel> Completion<M> for Agent<M> {
58 async fn completion(
59 &self,
60 prompt: impl Into<Message> + Send,
61 chat_history: Vec<Message>,
62 ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
63 let prompt = prompt.into();
64
65 let rag_text = prompt.rag_text();
67 let rag_text = rag_text.or_else(|| {
68 chat_history
69 .iter()
70 .rev()
71 .find_map(|message| message.rag_text())
72 });
73
74 let completion_request = self
75 .model
76 .completion_request(prompt)
77 .preamble(self.preamble.clone())
78 .messages(chat_history)
79 .temperature_opt(self.temperature)
80 .max_tokens_opt(self.max_tokens)
81 .additional_params_opt(self.additional_params.clone())
82 .documents(self.static_context.clone());
83
84 let agent = match &rag_text {
86 Some(text) => {
87 let dynamic_context = stream::iter(self.dynamic_context.iter())
88 .then(|(num_sample, index)| async {
89 Ok::<_, VectorStoreError>(
90 index
91 .top_n(text, *num_sample)
92 .await?
93 .into_iter()
94 .map(|(_, id, doc)| {
95 let text = serde_json::to_string_pretty(&doc)
97 .unwrap_or_else(|_| doc.to_string());
98
99 Document {
100 id,
101 text,
102 additional_props: HashMap::new(),
103 }
104 })
105 .collect::<Vec<_>>(),
106 )
107 })
108 .try_fold(vec![], |mut acc, docs| async {
109 acc.extend(docs);
110 Ok(acc)
111 })
112 .await
113 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
114
115 let dynamic_tools = stream::iter(self.dynamic_tools.iter())
116 .then(|(num_sample, index)| async {
117 Ok::<_, VectorStoreError>(
118 index
119 .top_n_ids(text, *num_sample)
120 .await?
121 .into_iter()
122 .map(|(_, id)| id)
123 .collect::<Vec<_>>(),
124 )
125 })
126 .try_fold(vec![], |mut acc, docs| async {
127 for doc in docs {
128 if let Some(tool) = self.tools.get(&doc) {
129 acc.push(tool.definition(text.into()).await)
130 } else {
131 tracing::warn!("Tool implementation not found in toolset: {}", doc);
132 }
133 }
134 Ok(acc)
135 })
136 .await
137 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
138
139 let static_tools = stream::iter(self.static_tools.iter())
140 .filter_map(|toolname| async move {
141 if let Some(tool) = self.tools.get(toolname) {
142 Some(tool.definition(text.into()).await)
143 } else {
144 tracing::warn!(
145 "Tool implementation not found in toolset: {}",
146 toolname
147 );
148 None
149 }
150 })
151 .collect::<Vec<_>>()
152 .await;
153
154 completion_request
155 .documents(dynamic_context)
156 .tools([static_tools.clone(), dynamic_tools].concat())
157 }
158 None => {
159 let static_tools = stream::iter(self.static_tools.iter())
160 .filter_map(|toolname| async move {
161 if let Some(tool) = self.tools.get(toolname) {
162 Some(tool.definition("".into()).await)
164 } else {
165 tracing::warn!(
166 "Tool implementation not found in toolset: {}",
167 toolname
168 );
169 None
170 }
171 })
172 .collect::<Vec<_>>()
173 .await;
174
175 completion_request.tools(static_tools)
176 }
177 };
178
179 Ok(agent)
180 }
181}
182
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Start a standard (non-streaming) prompt request against this agent.
    /// Builds a [`PromptRequest`] borrowing `self`; no work happens until it
    /// is driven by the caller.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<prompt_request::Standard, M> {
        PromptRequest::new(self, prompt)
    }
}
199
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Same as the `Agent<M>` impl, provided for `&Agent<M>` so callers
    /// holding a reference can prompt without an extra deref.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<prompt_request::Standard, M> {
        // Deref the double reference down to `&Agent<M>`.
        PromptRequest::new(*self, prompt)
    }
}
209
210#[allow(refining_impl_trait)]
211impl<M: CompletionModel> Chat for Agent<M> {
212 async fn chat(
213 &self,
214 prompt: impl Into<Message> + Send,
215 chat_history: Vec<Message>,
216 ) -> Result<String, PromptError> {
217 let mut cloned_history = chat_history.clone();
218 PromptRequest::new(self, prompt)
219 .with_history(&mut cloned_history)
220 .await
221 }
222}
223
impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
    /// Build the completion request used for streaming.
    ///
    /// Request assembly is identical to the non-streaming path, so this
    /// simply delegates to [`Completion::completion`]; only how the builder
    /// is consumed differs (see `stream_chat`).
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        self.completion(prompt, chat_history).await
    }
}
235
impl<M: CompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
    /// Stream a one-shot prompt: delegates to `stream_chat` with an empty
    /// chat history.
    async fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}
244
impl<M: CompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
    /// Stream a chat turn: build the completion request for `prompt` plus
    /// `chat_history`, then start the model's streaming response.
    ///
    /// # Errors
    /// Returns a [`CompletionError`] from either request assembly or from
    /// initiating the stream.
    async fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
        self.stream_completion(prompt, chat_history)
            .await?
            .stream()
            .await
    }
}