Skip to main content

chat_core/chat/
mod.rs

1use std::collections::HashMap;
2
3use serde_json::Value;
4use tools_rs::ToolError;
5
6use crate::{
7    chat::state::Unstructured,
8    error::ChatError,
9    types::{
10        callback::{CallbackStrategy, RetryStrategy},
11        messages::{content::Content, tool::ToolStatus},
12        options::ChatOptions,
13        response::PauseReason,
14        tools::{Action, ToolDeclarations, TypedCollection},
15    },
16};
17
18pub mod completion;
19pub mod embed;
20pub mod state;
21#[cfg(feature = "stream")]
22pub mod stream;
23#[cfg(feature = "stream")]
24pub mod input_stream;
25
26#[derive(Default)]
27pub struct Chat<CP, Output = Unstructured> {
28    pub(crate) model: CP,
29    pub(crate) output_shape: Option<schemars::Schema>,
30    pub(crate) model_options: Option<ChatOptions>,
31    pub(crate) max_steps: Option<u16>,
32    pub(crate) max_retries: Option<u16>,
33    pub(crate) retry_strategy: Option<RetryStrategy>,
34    pub(crate) before_strategy: Option<CallbackStrategy>,
35    pub(crate) after_strategy: Option<CallbackStrategy>,
36    pub(crate) scoped_collections: Vec<Box<dyn TypedCollection>>,
37    pub(crate) routing: HashMap<String, usize>,
38    pub(crate) _output: std::marker::PhantomData<Output>,
39}
40
41/// Outcome of a single `tool_call` pass over the most recent Content.
42///
43/// - `executed`: at least one tool was run and populated (Completed/Failed/
44///   Rejected). Drives the continue-vs-stop decision in the chat loop.
45/// - `pause`: the strategy put at least one tool into a state that
46///   requires the caller to act. If `Some`, the loop returns
47///   `ChatOutcome::Paused` with this reason.
48#[derive(Debug, Default)]
49pub(crate) struct ToolCallPass {
50    pub executed: bool,
51    pub pause: Option<PauseReason>,
52}
53
54/// Aggregates every scoped collection's declarations into a single
55/// `ToolDeclarations` implementation that providers consume.
56///
57/// Borrowed against `&Chat` — lives only for the duration of a single
58/// request. Concatenation is lazy: `.json()` walks the collections on
59/// demand, so there's no per-request allocation beyond the resulting
60/// JSON array.
61pub(crate) struct AggregatedDeclarations<'a> {
62    collections: &'a [Box<dyn TypedCollection>],
63}
64
65impl<'a> ToolDeclarations for AggregatedDeclarations<'a> {
66    fn json(&self) -> Result<Value, ToolError> {
67        let mut all = Vec::new();
68        for coll in self.collections {
69            if let Value::Array(arr) = coll.declarations()? {
70                all.extend(arr);
71            }
72        }
73        Ok(Value::Array(all))
74    }
75}
76
77/// Build an `AggregatedDeclarations` view directly from a slice. Used
78/// by the chat loop so the borrow is scoped to `scoped_collections` and
79/// not the whole `Chat` — lets callers still borrow `self.model`
80/// mutably.
81pub(crate) fn tool_declarations_from(
82    collections: &[Box<dyn TypedCollection>],
83) -> Option<AggregatedDeclarations<'_>> {
84    if collections.is_empty() {
85        None
86    } else {
87        Some(AggregatedDeclarations { collections })
88    }
89}
90
91impl<P, Output> Chat<P, Output> {
92    /// Look up which scoped collection owns a given tool name.
93    fn collection_for(&self, name: &str) -> Option<&dyn TypedCollection> {
94        self.routing
95            .get(name)
96            .and_then(|&idx| self.scoped_collections.get(idx).map(|b| b.as_ref()))
97    }
98
99    /// Run one pass over the tools in `content`: consult each tool's
100    /// strategy, execute the ones that say `Execute`, leave the ones
101    /// that require human approval or deferral in a non-resolved state
102    /// and accumulate them into a `PauseReason`.
103    pub(crate) async fn tool_call(&self, content: &mut Content) -> Result<ToolCallPass, ChatError> {
104        let mut pass = ToolCallPass::default();
105
106        // Phase 1: strategy decisions + execution. Walk tools in order;
107        // every unresolved, non-approved tool gets its strategy consulted.
108        // Tools already in `Approved` state (typically from a resume
109        // after human review) skip straight to execution.
110        let mut idx = 0;
111        while idx < content.parts.0.len() {
112            let part = &mut content.parts.0[idx];
113            let tool = match part {
114                crate::types::messages::parts::PartEnum::Tool(t) => t,
115                _ => {
116                    idx += 1;
117                    continue;
118                }
119            };
120
121            if tool.is_resolved() {
122                idx += 1;
123                continue;
124            }
125
126            // Approved either by human review on resume, or programmatically.
127            // Skip strategy — execute directly.
128            let already_approved = matches!(tool.status, ToolStatus::Approved { .. });
129
130            let action = if already_approved {
131                Action::Execute
132            } else {
133                let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
134                    ChatError::InvalidResponse(format!(
135                        "no scoped collection owns tool `{}`",
136                        tool.call.name
137                    ))
138                })?;
139                coll.decide(&tool.call)
140            };
141
142            match action {
143                Action::Execute => {
144                    // Run the tool via its owning collection.
145                    let coll = self.collection_for(&tool.call.name).ok_or_else(|| {
146                        ChatError::InvalidResponse(format!(
147                            "no scoped collection owns tool `{}`",
148                            tool.call.name
149                        ))
150                    })?;
151                    tool.mark_running();
152                    let call = tool.effective_call().clone();
153                    match coll.call(call).await {
154                        Ok(response) => tool.complete(response),
155                        Err(ToolError::Runtime(msg)) => tool.fail(msg),
156                        Err(e) => tool.fail(format!("{e:?}")),
157                    }
158                    pass.executed = true;
159                }
160                Action::RequireApproval => {
161                    // Leave Pending. Collect id for pause.
162                    match &mut pass.pause {
163                        None => {
164                            pass.pause = Some(PauseReason::AwaitingApproval {
165                                tool_ids: vec![tool.id.clone()],
166                            });
167                        }
168                        Some(PauseReason::AwaitingApproval { tool_ids }) => {
169                            tool_ids.push(tool.id.clone());
170                        }
171                        Some(PauseReason::Scheduled { .. }) => {
172                            let prev = std::mem::replace(
173                                &mut pass.pause,
174                                Some(PauseReason::AwaitingApproval { tool_ids: vec![] }),
175                            );
176                            if let Some(PauseReason::Scheduled {
177                                tool_ids: sch_ids,
178                                earliest,
179                            }) = prev
180                            {
181                                pass.pause = Some(PauseReason::Mixed {
182                                    approvals: vec![tool.id.clone()],
183                                    scheduled: sch_ids
184                                        .into_iter()
185                                        .map(|id| (id, earliest))
186                                        .collect(),
187                                });
188                            }
189                        }
190                        Some(PauseReason::Mixed { approvals, .. }) => {
191                            approvals.push(tool.id.clone());
192                        }
193                    }
194                }
195                Action::Defer { at } => match &mut pass.pause {
196                    None => {
197                        pass.pause = Some(PauseReason::Scheduled {
198                            tool_ids: vec![tool.id.clone()],
199                            earliest: at,
200                        });
201                    }
202                    Some(PauseReason::Scheduled { tool_ids, earliest }) => {
203                        tool_ids.push(tool.id.clone());
204                        if at < *earliest {
205                            *earliest = at;
206                        }
207                    }
208                    Some(PauseReason::AwaitingApproval { tool_ids }) => {
209                        let approvals = std::mem::take(tool_ids);
210                        pass.pause = Some(PauseReason::Mixed {
211                            approvals,
212                            scheduled: vec![(tool.id.clone(), at)],
213                        });
214                    }
215                    Some(PauseReason::Mixed { scheduled, .. }) => {
216                        scheduled.push((tool.id.clone(), at));
217                    }
218                },
219                Action::Reject { reason } => {
220                    tool.reject(Some(reason));
221                    pass.executed = true;
222                }
223            }
224
225            idx += 1;
226        }
227
228        Ok(pass)
229    }
230}