Skip to main content

chat_core/chat/
mod.rs

1use std::collections::HashMap;
2
3use serde_json::Value;
4use tools_rs::ToolError;
5
6use crate::{
7    chat::state::Unstructured,
8    error::ChatError,
9    types::{
10        callback::{CallbackStrategy, RetryStrategy},
11        messages::{content::Content, tool::ToolStatus},
12        options::ChatOptions,
13        response::PauseReason,
14        tools::{Action, ToolDeclarations, TypedCollection},
15    },
16};
17
18pub mod completion;
19pub mod embed;
20pub mod state;
21#[cfg(feature = "stream")]
22pub mod stream;
23
24#[derive(Default)]
25pub struct Chat<CP, Output = Unstructured> {
26    pub(crate) model: CP,
27    pub(crate) output_shape: Option<schemars::Schema>,
28    pub(crate) model_options: Option<ChatOptions>,
29    pub(crate) max_steps: Option<u16>,
30    pub(crate) max_retries: Option<u16>,
31    pub(crate) retry_strategy: Option<RetryStrategy>,
32    pub(crate) before_strategy: Option<CallbackStrategy>,
33    pub(crate) after_strategy: Option<CallbackStrategy>,
34    pub(crate) scoped_collections: Vec<Box<dyn TypedCollection>>,
35    pub(crate) routing: HashMap<String, usize>,
36    pub(crate) _output: std::marker::PhantomData<Output>,
37}
38
39/// Outcome of a single `tool_call` pass over the most recent Content.
40///
41/// - `executed`: at least one tool was run and populated (Completed/Failed/
42///   Rejected). Drives the continue-vs-stop decision in the chat loop.
43/// - `pause`: the strategy put at least one tool into a state that
44///   requires the caller to act. If `Some`, the loop returns
45///   `ChatOutcome::Paused` with this reason.
46#[derive(Debug, Default)]
47pub(crate) struct ToolCallPass {
48    pub executed: bool,
49    pub pause: Option<PauseReason>,
50}
51
52/// Aggregates every scoped collection's declarations into a single
53/// `ToolDeclarations` implementation that providers consume.
54///
55/// Borrowed against `&Chat` — lives only for the duration of a single
56/// request. Concatenation is lazy: `.json()` walks the collections on
57/// demand, so there's no per-request allocation beyond the resulting
58/// JSON array.
59pub(crate) struct AggregatedDeclarations<'a> {
60    collections: &'a [Box<dyn TypedCollection>],
61}
62
63impl<'a> ToolDeclarations for AggregatedDeclarations<'a> {
64    fn json(&self) -> Result<Value, ToolError> {
65        let mut all = Vec::new();
66        for coll in self.collections {
67            if let Value::Array(arr) = coll.declarations()? {
68                all.extend(arr);
69            }
70        }
71        Ok(Value::Array(all))
72    }
73}
74
75/// Build an `AggregatedDeclarations` view directly from a slice. Used
76/// by the chat loop so the borrow is scoped to `scoped_collections` and
77/// not the whole `Chat` — lets callers still borrow `self.model`
78/// mutably.
79pub(crate) fn tool_declarations_from(
80    collections: &[Box<dyn TypedCollection>],
81) -> Option<AggregatedDeclarations<'_>> {
82    if collections.is_empty() {
83        None
84    } else {
85        Some(AggregatedDeclarations { collections })
86    }
87}
88
89impl<P, Output> Chat<P, Output> {
90
91    /// Look up which scoped collection owns a given tool name.
92    fn collection_for(&self, name: &str) -> Option<&dyn TypedCollection> {
93        self.routing
94            .get(name)
95            .and_then(|&idx| self.scoped_collections.get(idx).map(|b| b.as_ref()))
96    }
97
98    /// Run one pass over the tools in `content`: consult each tool's
99    /// strategy, execute the ones that say `Execute`, leave the ones
100    /// that require human approval or deferral in a non-resolved state
101    /// and accumulate them into a `PauseReason`.
102    pub(crate) async fn tool_call(
103        &self,
104        content: &mut Content,
105    ) -> Result<ToolCallPass, ChatError> {
106        let mut pass = ToolCallPass::default();
107
108        // Phase 1: strategy decisions + execution. Walk tools in order;
109        // every unresolved, non-approved tool gets its strategy consulted.
110        // Tools already in `Approved` state (typically from a resume
111        // after human review) skip straight to execution.
112        let mut idx = 0;
113        while idx < content.parts.0.len() {
114            let part = &mut content.parts.0[idx];
115            let tool = match part {
116                crate::types::messages::parts::PartEnum::Tool(t) => t,
117                _ => {
118                    idx += 1;
119                    continue;
120                }
121            };
122
123            if tool.is_resolved() {
124                idx += 1;
125                continue;
126            }
127
128            // Approved either by human review on resume, or programmatically.
129            // Skip strategy — execute directly.
130            let already_approved = matches!(tool.status, ToolStatus::Approved { .. });
131
132            let action = if already_approved {
133                Action::Execute
134            } else {
135                let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
136                    ChatError::InvalidResponse(format!(
137                        "no scoped collection owns tool `{}`",
138                        tool.call.name
139                    ))
140                })?;
141                coll.decide(&tool.call)
142            };
143
144            match action {
145                Action::Execute => {
146                    // Run the tool via its owning collection.
147                    let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
148                        ChatError::InvalidResponse(format!(
149                            "no scoped collection owns tool `{}`",
150                            tool.call.name
151                        ))
152                    })?;
153                    tool.mark_running();
154                    let call = tool.effective_call().clone();
155                    match coll.call(call).await {
156                        Ok(response) => tool.complete(response),
157                        Err(ToolError::Runtime(msg)) => tool.fail(msg),
158                        Err(e) => tool.fail(format!("{e:?}")),
159                    }
160                    pass.executed = true;
161                }
162                Action::RequireApproval => {
163                    // Leave Pending. Collect id for pause.
164                    match &mut pass.pause {
165                        None => {
166                            pass.pause = Some(PauseReason::AwaitingApproval {
167                                tool_ids: vec![tool.id.clone()],
168                            });
169                        }
170                        Some(PauseReason::AwaitingApproval { tool_ids }) => {
171                            tool_ids.push(tool.id.clone());
172                        }
173                        Some(PauseReason::Scheduled { .. }) => {
174                            let prev = std::mem::replace(
175                                &mut pass.pause,
176                                Some(PauseReason::AwaitingApproval { tool_ids: vec![] }),
177                            );
178                            if let Some(PauseReason::Scheduled {
179                                tool_ids: sch_ids,
180                                earliest,
181                            }) = prev
182                            {
183                                pass.pause = Some(PauseReason::Mixed {
184                                    approvals: vec![tool.id.clone()],
185                                    scheduled: sch_ids
186                                        .into_iter()
187                                        .map(|id| (id, earliest))
188                                        .collect(),
189                                });
190                            }
191                        }
192                        Some(PauseReason::Mixed { approvals, .. }) => {
193                            approvals.push(tool.id.clone());
194                        }
195                    }
196                }
197                Action::Defer { at } => {
198                    match &mut pass.pause {
199                        None => {
200                            pass.pause = Some(PauseReason::Scheduled {
201                                tool_ids: vec![tool.id.clone()],
202                                earliest: at,
203                            });
204                        }
205                        Some(PauseReason::Scheduled {
206                            tool_ids,
207                            earliest,
208                        }) => {
209                            tool_ids.push(tool.id.clone());
210                            if at < *earliest {
211                                *earliest = at;
212                            }
213                        }
214                        Some(PauseReason::AwaitingApproval { tool_ids }) => {
215                            let approvals = std::mem::take(tool_ids);
216                            pass.pause = Some(PauseReason::Mixed {
217                                approvals,
218                                scheduled: vec![(tool.id.clone(), at)],
219                            });
220                        }
221                        Some(PauseReason::Mixed { scheduled, .. }) => {
222                            scheduled.push((tool.id.clone(), at));
223                        }
224                    }
225                }
226                Action::Reject { reason } => {
227                    tool.reject(Some(reason));
228                    pass.executed = true;
229                }
230            }
231
232            idx += 1;
233        }
234
235        Ok(pass)
236    }
237}