1use std::collections::HashMap;
2
3use serde_json::Value;
4use tools_rs::ToolError;
5
6use crate::{
7 chat::state::Unstructured,
8 error::ChatError,
9 types::{
10 callback::{CallbackStrategy, RetryStrategy},
11 messages::{content::Content, tool::ToolStatus},
12 options::ChatOptions,
13 response::PauseReason,
14 tools::{Action, ToolDeclarations, TypedCollection},
15 },
16};
17
18pub mod completion;
19pub mod embed;
20pub mod state;
21#[cfg(feature = "stream")]
22pub mod stream;
23
24#[derive(Default)]
25pub struct Chat<CP, Output = Unstructured> {
26 pub(crate) model: CP,
27 pub(crate) output_shape: Option<schemars::Schema>,
28 pub(crate) model_options: Option<ChatOptions>,
29 pub(crate) max_steps: Option<u16>,
30 pub(crate) max_retries: Option<u16>,
31 pub(crate) retry_strategy: Option<RetryStrategy>,
32 pub(crate) before_strategy: Option<CallbackStrategy>,
33 pub(crate) after_strategy: Option<CallbackStrategy>,
34 pub(crate) scoped_collections: Vec<Box<dyn TypedCollection>>,
35 pub(crate) routing: HashMap<String, usize>,
36 pub(crate) _output: std::marker::PhantomData<Output>,
37}
38
39#[derive(Debug, Default)]
47pub(crate) struct ToolCallPass {
48 pub executed: bool,
49 pub pause: Option<PauseReason>,
50}
51
52pub(crate) struct AggregatedDeclarations<'a> {
60 collections: &'a [Box<dyn TypedCollection>],
61}
62
63impl<'a> ToolDeclarations for AggregatedDeclarations<'a> {
64 fn json(&self) -> Result<Value, ToolError> {
65 let mut all = Vec::new();
66 for coll in self.collections {
67 if let Value::Array(arr) = coll.declarations()? {
68 all.extend(arr);
69 }
70 }
71 Ok(Value::Array(all))
72 }
73}
74
75pub(crate) fn tool_declarations_from(
80 collections: &[Box<dyn TypedCollection>],
81) -> Option<AggregatedDeclarations<'_>> {
82 if collections.is_empty() {
83 None
84 } else {
85 Some(AggregatedDeclarations { collections })
86 }
87}
88
89impl<P, Output> Chat<P, Output> {
90
91 fn collection_for(&self, name: &str) -> Option<&dyn TypedCollection> {
93 self.routing
94 .get(name)
95 .and_then(|&idx| self.scoped_collections.get(idx).map(|b| b.as_ref()))
96 }
97
98 pub(crate) async fn tool_call(
103 &self,
104 content: &mut Content,
105 ) -> Result<ToolCallPass, ChatError> {
106 let mut pass = ToolCallPass::default();
107
108 let mut idx = 0;
113 while idx < content.parts.0.len() {
114 let part = &mut content.parts.0[idx];
115 let tool = match part {
116 crate::types::messages::parts::PartEnum::Tool(t) => t,
117 _ => {
118 idx += 1;
119 continue;
120 }
121 };
122
123 if tool.is_resolved() {
124 idx += 1;
125 continue;
126 }
127
128 let already_approved = matches!(tool.status, ToolStatus::Approved { .. });
131
132 let action = if already_approved {
133 Action::Execute
134 } else {
135 let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
136 ChatError::InvalidResponse(format!(
137 "no scoped collection owns tool `{}`",
138 tool.call.name
139 ))
140 })?;
141 coll.decide(&tool.call)
142 };
143
144 match action {
145 Action::Execute => {
146 let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
148 ChatError::InvalidResponse(format!(
149 "no scoped collection owns tool `{}`",
150 tool.call.name
151 ))
152 })?;
153 tool.mark_running();
154 let call = tool.effective_call().clone();
155 match coll.call(call).await {
156 Ok(response) => tool.complete(response),
157 Err(ToolError::Runtime(msg)) => tool.fail(msg),
158 Err(e) => tool.fail(format!("{e:?}")),
159 }
160 pass.executed = true;
161 }
162 Action::RequireApproval => {
163 match &mut pass.pause {
165 None => {
166 pass.pause = Some(PauseReason::AwaitingApproval {
167 tool_ids: vec![tool.id.clone()],
168 });
169 }
170 Some(PauseReason::AwaitingApproval { tool_ids }) => {
171 tool_ids.push(tool.id.clone());
172 }
173 Some(PauseReason::Scheduled { .. }) => {
174 let prev = std::mem::replace(
175 &mut pass.pause,
176 Some(PauseReason::AwaitingApproval { tool_ids: vec![] }),
177 );
178 if let Some(PauseReason::Scheduled {
179 tool_ids: sch_ids,
180 earliest,
181 }) = prev
182 {
183 pass.pause = Some(PauseReason::Mixed {
184 approvals: vec![tool.id.clone()],
185 scheduled: sch_ids
186 .into_iter()
187 .map(|id| (id, earliest))
188 .collect(),
189 });
190 }
191 }
192 Some(PauseReason::Mixed { approvals, .. }) => {
193 approvals.push(tool.id.clone());
194 }
195 }
196 }
197 Action::Defer { at } => {
198 match &mut pass.pause {
199 None => {
200 pass.pause = Some(PauseReason::Scheduled {
201 tool_ids: vec![tool.id.clone()],
202 earliest: at,
203 });
204 }
205 Some(PauseReason::Scheduled {
206 tool_ids,
207 earliest,
208 }) => {
209 tool_ids.push(tool.id.clone());
210 if at < *earliest {
211 *earliest = at;
212 }
213 }
214 Some(PauseReason::AwaitingApproval { tool_ids }) => {
215 let approvals = std::mem::take(tool_ids);
216 pass.pause = Some(PauseReason::Mixed {
217 approvals,
218 scheduled: vec![(tool.id.clone(), at)],
219 });
220 }
221 Some(PauseReason::Mixed { scheduled, .. }) => {
222 scheduled.push((tool.id.clone(), at));
223 }
224 }
225 }
226 Action::Reject { reason } => {
227 tool.reject(Some(reason));
228 pass.executed = true;
229 }
230 }
231
232 idx += 1;
233 }
234
235 Ok(pass)
236 }
237}