1use std::collections::HashMap;
2
3use serde_json::Value;
4use tools_rs::ToolError;
5
6use crate::{
7 chat::state::Unstructured,
8 error::ChatError,
9 types::{
10 callback::{CallbackStrategy, RetryStrategy},
11 messages::{content::Content, tool::ToolStatus},
12 options::ChatOptions,
13 response::PauseReason,
14 tools::{Action, ToolDeclarations, TypedCollection},
15 },
16};
17
18pub mod completion;
19pub mod embed;
20pub mod state;
21#[cfg(feature = "stream")]
22pub mod stream;
23#[cfg(feature = "stream")]
24pub mod input_stream;
25
26#[derive(Default)]
27pub struct Chat<CP, Output = Unstructured> {
28 pub(crate) model: CP,
29 pub(crate) output_shape: Option<schemars::Schema>,
30 pub(crate) model_options: Option<ChatOptions>,
31 pub(crate) max_steps: Option<u16>,
32 pub(crate) max_retries: Option<u16>,
33 pub(crate) retry_strategy: Option<RetryStrategy>,
34 pub(crate) before_strategy: Option<CallbackStrategy>,
35 pub(crate) after_strategy: Option<CallbackStrategy>,
36 pub(crate) scoped_collections: Vec<Box<dyn TypedCollection>>,
37 pub(crate) routing: HashMap<String, usize>,
38 pub(crate) _output: std::marker::PhantomData<Output>,
39}
40
/// Outcome of a single tool-call pass over a message's parts.
#[derive(Debug, Default)]
pub(crate) struct ToolCallPass {
    /// `true` if at least one tool call was executed or rejected during the pass.
    pub executed: bool,
    /// Tool calls the pass left pending (awaiting approval and/or deferred to a
    /// later time), or `None` if nothing was left pending.
    pub pause: Option<PauseReason>,
}
53
/// Presents several scoped tool collections as one [`ToolDeclarations`] source
/// by concatenating their declaration arrays.
pub(crate) struct AggregatedDeclarations<'a> {
    /// The collections whose declarations are merged, in slice order.
    collections: &'a [Box<dyn TypedCollection>],
}
64
65impl<'a> ToolDeclarations for AggregatedDeclarations<'a> {
66 fn json(&self) -> Result<Value, ToolError> {
67 let mut all = Vec::new();
68 for coll in self.collections {
69 if let Value::Array(arr) = coll.declarations()? {
70 all.extend(arr);
71 }
72 }
73 Ok(Value::Array(all))
74 }
75}
76
77pub(crate) fn tool_declarations_from(
82 collections: &[Box<dyn TypedCollection>],
83) -> Option<AggregatedDeclarations<'_>> {
84 if collections.is_empty() {
85 None
86 } else {
87 Some(AggregatedDeclarations { collections })
88 }
89}
90
91impl<P, Output> Chat<P, Output> {
92 fn collection_for(&self, name: &str) -> Option<&dyn TypedCollection> {
94 self.routing
95 .get(name)
96 .and_then(|&idx| self.scoped_collections.get(idx).map(|b| b.as_ref()))
97 }
98
99 pub(crate) async fn tool_call(&self, content: &mut Content) -> Result<ToolCallPass, ChatError> {
104 let mut pass = ToolCallPass::default();
105
106 let mut idx = 0;
111 while idx < content.parts.0.len() {
112 let part = &mut content.parts.0[idx];
113 let tool = match part {
114 crate::types::messages::parts::PartEnum::Tool(t) => t,
115 _ => {
116 idx += 1;
117 continue;
118 }
119 };
120
121 if tool.is_resolved() {
122 idx += 1;
123 continue;
124 }
125
126 let already_approved = matches!(tool.status, ToolStatus::Approved { .. });
129
130 let action = if already_approved {
131 Action::Execute
132 } else {
133 let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
134 ChatError::InvalidResponse(format!(
135 "no scoped collection owns tool `{}`",
136 tool.call.name
137 ))
138 })?;
139 coll.decide(&tool.call)
140 };
141
142 match action {
143 Action::Execute => {
144 let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
146 ChatError::InvalidResponse(format!(
147 "no scoped collection owns tool `{}`",
148 tool.call.name
149 ))
150 })?;
151 tool.mark_running();
152 let call = tool.effective_call().clone();
153 match coll.call(call).await {
154 Ok(response) => tool.complete(response),
155 Err(ToolError::Runtime(msg)) => tool.fail(msg),
156 Err(e) => tool.fail(format!("{e:?}")),
157 }
158 pass.executed = true;
159 }
160 Action::RequireApproval => {
161 match &mut pass.pause {
163 None => {
164 pass.pause = Some(PauseReason::AwaitingApproval {
165 tool_ids: vec![tool.id.clone()],
166 });
167 }
168 Some(PauseReason::AwaitingApproval { tool_ids }) => {
169 tool_ids.push(tool.id.clone());
170 }
171 Some(PauseReason::Scheduled { .. }) => {
172 let prev = std::mem::replace(
173 &mut pass.pause,
174 Some(PauseReason::AwaitingApproval { tool_ids: vec![] }),
175 );
176 if let Some(PauseReason::Scheduled {
177 tool_ids: sch_ids,
178 earliest,
179 }) = prev
180 {
181 pass.pause = Some(PauseReason::Mixed {
182 approvals: vec![tool.id.clone()],
183 scheduled: sch_ids
184 .into_iter()
185 .map(|id| (id, earliest))
186 .collect(),
187 });
188 }
189 }
190 Some(PauseReason::Mixed { approvals, .. }) => {
191 approvals.push(tool.id.clone());
192 }
193 }
194 }
195 Action::Defer { at } => match &mut pass.pause {
196 None => {
197 pass.pause = Some(PauseReason::Scheduled {
198 tool_ids: vec![tool.id.clone()],
199 earliest: at,
200 });
201 }
202 Some(PauseReason::Scheduled { tool_ids, earliest }) => {
203 tool_ids.push(tool.id.clone());
204 if at < *earliest {
205 *earliest = at;
206 }
207 }
208 Some(PauseReason::AwaitingApproval { tool_ids }) => {
209 let approvals = std::mem::take(tool_ids);
210 pass.pause = Some(PauseReason::Mixed {
211 approvals,
212 scheduled: vec![(tool.id.clone(), at)],
213 });
214 }
215 Some(PauseReason::Mixed { scheduled, .. }) => {
216 scheduled.push((tool.id.clone(), at));
217 }
218 },
219 Action::Reject { reason } => {
220 tool.reject(Some(reason));
221 pass.executed = true;
222 }
223 }
224
225 idx += 1;
226 }
227
228 Ok(pass)
229 }
230}