1use std::collections::HashMap;
2
3use serde_json::Value;
4use tools_rs::ToolError;
5
6use crate::{
7 chat::state::Unstructured,
8 error::ChatError,
9 types::{
10 callback::{CallbackStrategy, RetryStrategy},
11 messages::{content::Content, tool::ToolStatus},
12 options::ChatOptions,
13 response::PauseReason,
14 tools::{Action, ToolDeclarations, TypedCollection},
15 },
16};
17
18pub mod completion;
19pub mod embed;
20pub mod state;
21#[cfg(feature = "stream")]
22pub mod stream;
23
/// A chat session over the model handle `CP`, typed by its `Output`
/// (which defaults to `Unstructured`).
///
/// Holds the model, optional generation/retry/callback configuration, and the
/// tool collections that tool calls are dispatched to.
#[derive(Default)]
pub struct Chat<CP, Output = Unstructured> {
    /// The model handle this chat talks to.
    pub(crate) model: CP,
    /// Optional `schemars` schema for the expected output shape.
    /// NOTE(review): presumably derived from `Output` for structured replies — confirm.
    pub(crate) output_shape: Option<schemars::Schema>,
    /// Optional `ChatOptions` passed to the model; `None` when unset.
    pub(crate) model_options: Option<ChatOptions>,
    /// Optional step cap.
    /// NOTE(review): presumably bounds model/tool rounds; not read in this chunk.
    pub(crate) max_steps: Option<u16>,
    /// Optional retry cap; see `retry_strategy`. Not read in this chunk.
    pub(crate) max_retries: Option<u16>,
    /// Optional strategy used when retrying. Not read in this chunk.
    pub(crate) retry_strategy: Option<RetryStrategy>,
    /// Optional callback strategy for the "before" hook. Not read in this chunk.
    pub(crate) before_strategy: Option<CallbackStrategy>,
    /// Optional callback strategy for the "after" hook. Not read in this chunk.
    pub(crate) after_strategy: Option<CallbackStrategy>,
    /// Tool collections available to this chat. Their declarations are merged
    /// by `tool_declarations_from`, and tool calls are dispatched to them.
    pub(crate) scoped_collections: Vec<Box<dyn TypedCollection>>,
    /// Tool name -> index into `scoped_collections`, used to find the
    /// collection that owns a given tool call.
    pub(crate) routing: HashMap<String, usize>,
    /// Marks the `Output` type parameter without storing a value.
    pub(crate) _output: std::marker::PhantomData<Output>,
}
38
/// Outcome of one `Chat::tool_call` pass over a message's tool parts.
#[derive(Debug, Default)]
pub(crate) struct ToolCallPass {
    /// `true` if at least one tool was executed or rejected during the pass.
    pub executed: bool,
    /// Why tools were left pending (awaiting approval and/or deferred);
    /// `None` when nothing is pending.
    pub pause: Option<PauseReason>,
}
51
/// Borrowed view over several tool collections that exposes their
/// declarations as one combined `ToolDeclarations` value.
pub(crate) struct AggregatedDeclarations<'a> {
    /// Collections whose declarations are concatenated, in slice order.
    collections: &'a [Box<dyn TypedCollection>],
}
62
63impl<'a> ToolDeclarations for AggregatedDeclarations<'a> {
64 fn json(&self) -> Result<Value, ToolError> {
65 let mut all = Vec::new();
66 for coll in self.collections {
67 if let Value::Array(arr) = coll.declarations()? {
68 all.extend(arr);
69 }
70 }
71 Ok(Value::Array(all))
72 }
73}
74
75pub(crate) fn tool_declarations_from(
80 collections: &[Box<dyn TypedCollection>],
81) -> Option<AggregatedDeclarations<'_>> {
82 if collections.is_empty() {
83 None
84 } else {
85 Some(AggregatedDeclarations { collections })
86 }
87}
88
89impl<P, Output> Chat<P, Output> {
90 fn collection_for(&self, name: &str) -> Option<&dyn TypedCollection> {
92 self.routing
93 .get(name)
94 .and_then(|&idx| self.scoped_collections.get(idx).map(|b| b.as_ref()))
95 }
96
97 pub(crate) async fn tool_call(&self, content: &mut Content) -> Result<ToolCallPass, ChatError> {
102 let mut pass = ToolCallPass::default();
103
104 let mut idx = 0;
109 while idx < content.parts.0.len() {
110 let part = &mut content.parts.0[idx];
111 let tool = match part {
112 crate::types::messages::parts::PartEnum::Tool(t) => t,
113 _ => {
114 idx += 1;
115 continue;
116 }
117 };
118
119 if tool.is_resolved() {
120 idx += 1;
121 continue;
122 }
123
124 let already_approved = matches!(tool.status, ToolStatus::Approved { .. });
127
128 let action = if already_approved {
129 Action::Execute
130 } else {
131 let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
132 ChatError::InvalidResponse(format!(
133 "no scoped collection owns tool `{}`",
134 tool.call.name
135 ))
136 })?;
137 coll.decide(&tool.call)
138 };
139
140 match action {
141 Action::Execute => {
142 let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
144 ChatError::InvalidResponse(format!(
145 "no scoped collection owns tool `{}`",
146 tool.call.name
147 ))
148 })?;
149 tool.mark_running();
150 let call = tool.effective_call().clone();
151 match coll.call(call).await {
152 Ok(response) => tool.complete(response),
153 Err(ToolError::Runtime(msg)) => tool.fail(msg),
154 Err(e) => tool.fail(format!("{e:?}")),
155 }
156 pass.executed = true;
157 }
158 Action::RequireApproval => {
159 match &mut pass.pause {
161 None => {
162 pass.pause = Some(PauseReason::AwaitingApproval {
163 tool_ids: vec![tool.id.clone()],
164 });
165 }
166 Some(PauseReason::AwaitingApproval { tool_ids }) => {
167 tool_ids.push(tool.id.clone());
168 }
169 Some(PauseReason::Scheduled { .. }) => {
170 let prev = std::mem::replace(
171 &mut pass.pause,
172 Some(PauseReason::AwaitingApproval { tool_ids: vec![] }),
173 );
174 if let Some(PauseReason::Scheduled {
175 tool_ids: sch_ids,
176 earliest,
177 }) = prev
178 {
179 pass.pause = Some(PauseReason::Mixed {
180 approvals: vec![tool.id.clone()],
181 scheduled: sch_ids
182 .into_iter()
183 .map(|id| (id, earliest))
184 .collect(),
185 });
186 }
187 }
188 Some(PauseReason::Mixed { approvals, .. }) => {
189 approvals.push(tool.id.clone());
190 }
191 }
192 }
193 Action::Defer { at } => match &mut pass.pause {
194 None => {
195 pass.pause = Some(PauseReason::Scheduled {
196 tool_ids: vec![tool.id.clone()],
197 earliest: at,
198 });
199 }
200 Some(PauseReason::Scheduled { tool_ids, earliest }) => {
201 tool_ids.push(tool.id.clone());
202 if at < *earliest {
203 *earliest = at;
204 }
205 }
206 Some(PauseReason::AwaitingApproval { tool_ids }) => {
207 let approvals = std::mem::take(tool_ids);
208 pass.pause = Some(PauseReason::Mixed {
209 approvals,
210 scheduled: vec![(tool.id.clone(), at)],
211 });
212 }
213 Some(PauseReason::Mixed { scheduled, .. }) => {
214 scheduled.push((tool.id.clone(), at));
215 }
216 },
217 Action::Reject { reason } => {
218 tool.reject(Some(reason));
219 pass.executed = true;
220 }
221 }
222
223 idx += 1;
224 }
225
226 Ok(pass)
227 }
228}