Skip to main content

crabtalk_core/
model.rs

1//! Unified LLM interface types and the `Model<P>` wrapper.
2//!
3//! Thin re-export layer over `crabllm_core` for the core wire types
4//! (`Message`, `Tool`, `ToolCall`, `Usage`, …) plus crabtalk's own
5//! `HistoryEntry` wrapper and streaming `MessageBuilder`. `Model<P>` is the
6//! single seam between crabtalk and any `crabllm_core::Provider`.
7
8pub use crabllm_core::{
9    ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, CompletionTokensDetails,
10    FinishReason, FunctionCall, FunctionDef, Message, Role, Tool, ToolCall, ToolCallDelta,
11    ToolChoice, ToolType, Usage,
12};
13
14use anyhow::Result;
15use async_stream::try_stream;
16use crabllm_core::{ApiError, Provider};
17use futures_core::Stream;
18use futures_util::StreamExt;
19use serde::{Deserialize, Serialize};
20use std::{collections::BTreeMap, sync::Arc};
21
22// ── HistoryEntry ────────────────────────────────────────────────────
23
24/// A single conversation history entry.
25///
26/// The inner `message` is the wire-level shape sent to providers. The
27/// runtime-only fields are stripped from the wire but persisted to the
28/// session `Storage` for reload (except `sender` and `auto_injected`,
29/// which are session-local state that resets on reload).
30#[derive(Debug, Clone, Deserialize, Serialize)]
31pub struct HistoryEntry {
32    /// Which agent produced this assistant message. Empty = the conversation's
33    /// primary agent. Non-empty = a guest agent pulled in via an @ mention
34    /// or guest turn. Persisted so reloads can reconstruct multi-agent state.
35    #[serde(default, skip_serializing_if = "String::is_empty")]
36    pub agent: String,
37
38    /// The sender identity (runtime-only, never serialized).
39    #[serde(skip)]
40    pub sender: String,
41
42    /// Whether this entry was auto-injected by the runtime (runtime-only).
43    /// Auto-injected entries are stripped before each new run and never
44    /// persisted as session steps.
45    #[serde(skip)]
46    pub auto_injected: bool,
47
48    /// The wire-level message sent to providers.
49    pub message: Message,
50}
51
52impl HistoryEntry {
53    /// Create a new system entry.
54    pub fn system(content: impl Into<String>) -> Self {
55        Self::from_message(Message::system(content))
56    }
57
58    /// Create a new user entry.
59    pub fn user(content: impl Into<String>) -> Self {
60        Self::from_message(Message::user(content))
61    }
62
63    /// Create a new user entry with sender identity.
64    pub fn user_with_sender(content: impl Into<String>, sender: impl Into<String>) -> Self {
65        let mut entry = Self::user(content);
66        entry.sender = sender.into();
67        entry
68    }
69
70    /// Create a new assistant entry.
71    ///
72    /// Preserves the `content: null` vs empty-string discrimination:
73    /// - assistant + non-empty `tool_calls` + empty content → `"content": null`
74    /// - assistant + empty `tool_calls` + empty content → `"content": ""`
75    /// - anything else → `"content": "<the text>"`
76    pub fn assistant(
77        content: impl Into<String>,
78        reasoning: Option<String>,
79        tool_calls: Option<&[ToolCall]>,
80    ) -> Self {
81        let content: String = content.into();
82        let has_tool_calls = tool_calls.is_some_and(|tcs| !tcs.is_empty());
83        let message_content = if content.is_empty() && has_tool_calls {
84            Some(serde_json::Value::Null)
85        } else {
86            Some(serde_json::Value::String(content))
87        };
88        Self::from_message(Message {
89            role: Role::Assistant,
90            content: message_content,
91            tool_calls: tool_calls.map(|tcs| tcs.to_vec()),
92            tool_call_id: None,
93            name: None,
94            reasoning_content: reasoning.filter(|s| !s.is_empty()),
95            extra: Default::default(),
96        })
97    }
98
99    /// Create a new tool-result entry.
100    pub fn tool(
101        content: impl Into<String>,
102        call_id: impl Into<String>,
103        name: impl Into<String>,
104    ) -> Self {
105        Self::from_message(Message::tool(call_id, name, content))
106    }
107
108    /// Wrap an existing `crabllm_core::Message`.
109    pub fn from_message(message: Message) -> Self {
110        Self {
111            agent: String::new(),
112            sender: String::new(),
113            auto_injected: false,
114            message,
115        }
116    }
117
118    /// Mark this entry as auto-injected (chainable).
119    pub fn auto_injected(mut self) -> Self {
120        self.auto_injected = true;
121        self
122    }
123
124    /// The role of the underlying message.
125    pub fn role(&self) -> &Role {
126        &self.message.role
127    }
128
129    /// The text content of the message, or `""` if absent / empty / non-string.
130    pub fn text(&self) -> &str {
131        self.message.content_str().unwrap_or("")
132    }
133
134    /// The reasoning content, or empty if absent.
135    pub fn reasoning(&self) -> &str {
136        self.message.reasoning_content.as_deref().unwrap_or("")
137    }
138
139    /// The tool calls on this entry, or an empty slice if absent.
140    pub fn tool_calls(&self) -> &[ToolCall] {
141        self.message.tool_calls.as_deref().unwrap_or(&[])
142    }
143
144    /// The tool call ID on this (tool) entry, or empty if absent.
145    pub fn tool_call_id(&self) -> &str {
146        self.message.tool_call_id.as_deref().unwrap_or("")
147    }
148
149    /// Estimate the number of tokens in this entry (~4 chars per token).
150    pub fn estimate_tokens(&self) -> usize {
151        let chars = self.text().len()
152            + self.reasoning().len()
153            + self.tool_call_id().len()
154            + self
155                .tool_calls()
156                .iter()
157                .map(|tc| tc.function.name.len() + tc.function.arguments.len())
158                .sum::<usize>();
159        (chars / 4).max(1)
160    }
161
162    /// Project to a `crabllm_core::Message` for sending to a provider.
163    ///
164    /// If this is a guest assistant message (`agent` non-empty and role is
165    /// Assistant), wraps the content in `<from agent="...">` tags so other
166    /// agents can distinguish speakers in multi-agent conversations.
167    pub fn to_wire_message(&self) -> Message {
168        if self.message.role != Role::Assistant || self.agent.is_empty() {
169            return self.message.clone();
170        }
171        let tagged = format!("<from agent=\"{}\">\n{}\n</from>", self.agent, self.text());
172        Message {
173            role: Role::Assistant,
174            content: Some(serde_json::Value::String(tagged)),
175            tool_calls: self.message.tool_calls.clone(),
176            tool_call_id: self.message.tool_call_id.clone(),
177            name: self.message.name.clone(),
178            reasoning_content: self.message.reasoning_content.clone(),
179            extra: self.message.extra.clone(),
180        }
181    }
182}
183
184/// Estimate total tokens across a slice of entries.
185pub fn estimate_history_tokens(entries: &[HistoryEntry]) -> usize {
186    entries.iter().map(|e| e.estimate_tokens()).sum()
187}
188
189// ── MessageBuilder ──────────────────────────────────────────────────
190
191fn empty_tool_call() -> ToolCall {
192    ToolCall {
193        index: None,
194        id: String::new(),
195        kind: ToolType::Function,
196        function: FunctionCall::default(),
197    }
198}
199
200/// Accumulating builder for streaming assistant messages.
201pub struct MessageBuilder {
202    role: Role,
203    content: String,
204    reasoning: String,
205    calls: BTreeMap<u32, ToolCall>,
206}
207
208impl MessageBuilder {
209    /// Create a new builder for the given role (typically `Role::Assistant`).
210    pub fn new(role: Role) -> Self {
211        Self {
212            role,
213            content: String::new(),
214            reasoning: String::new(),
215            calls: BTreeMap::new(),
216        }
217    }
218
219    /// Accept one streaming chunk.
220    ///
221    /// Returns `true` if this chunk contributed visible text content.
222    pub fn accept(&mut self, chunk: &ChatCompletionChunk) -> bool {
223        let Some(choice) = chunk.choices.first() else {
224            return false;
225        };
226        let delta = &choice.delta;
227
228        let mut has_content = false;
229        if let Some(text) = delta.content.as_deref()
230            && !text.is_empty()
231        {
232            self.content.push_str(text);
233            has_content = true;
234        }
235        if let Some(reason) = delta.reasoning_content.as_deref()
236            && !reason.is_empty()
237        {
238            self.reasoning.push_str(reason);
239        }
240        if let Some(calls) = delta.tool_calls.as_deref() {
241            for call in calls {
242                self.merge_tool_call(call);
243            }
244        }
245        has_content
246    }
247
248    fn merge_tool_call(&mut self, delta: &ToolCallDelta) {
249        let entry = self
250            .calls
251            .entry(delta.index)
252            .or_insert_with(empty_tool_call);
253        entry.index = Some(delta.index);
254        if let Some(id) = &delta.id
255            && !id.is_empty()
256        {
257            entry.id = id.clone();
258        }
259        if let Some(kind) = delta.kind {
260            entry.kind = kind;
261        }
262        if let Some(function) = &delta.function {
263            if let Some(name) = &function.name
264                && !name.is_empty()
265            {
266                entry.function.name = name.clone();
267            }
268            if let Some(args) = &function.arguments {
269                entry.function.arguments.push_str(args);
270            }
271        }
272    }
273
274    /// Snapshot of tool calls accumulated so far.
275    pub fn peek_tool_calls(&self) -> Vec<ToolCall> {
276        self.calls
277            .values()
278            .filter(|c| !c.function.name.is_empty())
279            .cloned()
280            .collect()
281    }
282
283    /// Finalize the builder into a `crabllm_core::Message`.
284    pub fn build(self) -> Message {
285        let tool_calls: Vec<ToolCall> = self
286            .calls
287            .into_values()
288            .filter(|c| !c.id.is_empty() && !c.function.name.is_empty())
289            .collect();
290        let has_tool_calls = !tool_calls.is_empty();
291        let content = if self.content.is_empty() && has_tool_calls && self.role == Role::Assistant {
292            Some(serde_json::Value::Null)
293        } else {
294            Some(serde_json::Value::String(self.content))
295        };
296        let reasoning_content = if self.reasoning.is_empty() {
297            None
298        } else {
299            Some(self.reasoning)
300        };
301        Message {
302            role: self.role,
303            content,
304            tool_calls: if has_tool_calls {
305                Some(tool_calls)
306            } else {
307                None
308            },
309            tool_call_id: None,
310            name: None,
311            reasoning_content,
312            extra: Default::default(),
313        }
314    }
315}
316
317// ── Model<P> ────────────────────────────────────────────────────────
318
/// A crabtalk-side view over a `crabllm_core::Provider`.
///
/// Holds an `Arc<P>` so cloning is cheap. The `Provider + 'static` bounds
/// live on the `impl` blocks rather than on the struct definition; the
/// `'static` requirement comes from the streaming path, whose returned
/// `'static` stream owns its own clone of the `Arc<P>`.
pub struct Model<P> {
    inner: Arc<P>,
}
326
327impl<P: Provider + 'static> Model<P> {
328    /// Wrap a provider in a `Model`.
329    pub fn new(provider: P) -> Self {
330        Self {
331            inner: Arc::new(provider),
332        }
333    }
334
335    /// Wrap an existing `Arc<P>` without re-allocating.
336    pub fn from_arc(provider: Arc<P>) -> Self {
337        Self { inner: provider }
338    }
339
340    /// Send a non-streaming chat completion request.
341    pub async fn send_ct(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse> {
342        let mut req = request;
343        req.stream = Some(false);
344        let model_label = req.model.clone();
345        self.inner
346            .chat_completion(&req)
347            .await
348            .map_err(|e| format_provider_error(&model_label, "send", e))
349    }
350
351    /// Stream a chat completion response.
352    pub fn stream_ct(
353        &self,
354        request: ChatCompletionRequest,
355    ) -> impl Stream<Item = Result<ChatCompletionChunk>> + Send + 'static {
356        let inner = Arc::clone(&self.inner);
357        let mut req = request;
358        req.stream = Some(true);
359        let model_label = req.model.clone();
360        try_stream! {
361            let mut stream = inner
362                .chat_completion_stream(&req)
363                .await
364                .map_err(|e| format_provider_error(&model_label, "stream open", e))?;
365            while let Some(chunk) = stream.next().await {
366                yield chunk
367                    .map_err(|e| format_provider_error(&model_label, "stream chunk", e))?;
368            }
369        }
370    }
371}
372
373impl<P: Provider + 'static> Clone for Model<P> {
374    fn clone(&self) -> Self {
375        Self {
376            inner: Arc::clone(&self.inner),
377        }
378    }
379}
380
381impl<P: Provider + 'static> std::fmt::Debug for Model<P> {
382    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        f.debug_struct("Model").finish()
384    }
385}
386
387fn format_provider_error(model: &str, op: &str, e: crabllm_core::Error) -> anyhow::Error {
388    match e {
389        crabllm_core::Error::Provider { status, body } => {
390            let msg = serde_json::from_str::<ApiError>(&body)
391                .map(|api_err| api_err.error.message)
392                .unwrap_or_else(|_| truncate(&body, 200));
393            anyhow::anyhow!("model {op} failed for '{model}' (HTTP {status}): {msg}")
394        }
395        other => anyhow::anyhow!("model {op} failed for '{model}': {other}"),
396    }
397}
398
/// Keep at most `max` chars of `s`, appending `...` only when something
/// was actually cut. Cuts on char boundaries, so multi-byte text is safe.
fn truncate(s: &str, max: usize) -> String {
    if let Some((cut, _)) = s.char_indices().nth(max) {
        let mut shortened = String::with_capacity(cut + 3);
        shortened.push_str(&s[..cut]);
        shortened.push_str("...");
        shortened
    } else {
        s.to_owned()
    }
}
405
406// ── Context limits ──────────────────────────────────────────────────
407
/// Returns the default context limit (in tokens) for a known model ID.
///
/// Uses prefix matching against known model families. More specific
/// prefixes are checked before broader ones (e.g. `gpt-4.1` / `gpt-4o` /
/// `gpt-4-32k` before the bare `gpt-4`, `o1-mini` before `o1`), because a
/// broad prefix would otherwise swallow a family with a different window.
/// Unknown models return 8192 as a conservative default.
pub fn default_context_limit(model_id: &str) -> usize {
    if model_id.starts_with("claude-") {
        return 200_000;
    }

    // OpenAI GPT-4 family — order matters, `gpt-4` must come last.
    if model_id.starts_with("gpt-4.1") {
        return 1_047_576;
    }
    if model_id.starts_with("gpt-4o")
        || model_id.starts_with("gpt-4-turbo")
        || model_id.starts_with("gpt-4.5")
    {
        return 128_000;
    }
    if model_id.starts_with("gpt-4-32k") {
        return 32_768;
    }
    if model_id.starts_with("gpt-4") {
        return 8_192;
    }

    // `gpt-3.5-turbo-instruct` has a much smaller window than the chat models.
    if model_id.starts_with("gpt-3.5-turbo-instruct") {
        return 4_096;
    }
    if model_id.starts_with("gpt-3.5") {
        return 16_385;
    }

    // OpenAI reasoning models; the early o1 variants have 128k, not 200k.
    if model_id.starts_with("o1-mini") || model_id.starts_with("o1-preview") {
        return 128_000;
    }
    if model_id.starts_with("o1") || model_id.starts_with("o3") || model_id.starts_with("o4") {
        return 200_000;
    }

    if model_id.starts_with("grok-") {
        return 131_072;
    }
    if model_id.starts_with("qwen-") || model_id.starts_with("qwq-") {
        return 32_768;
    }

    // Moonshot encodes the window in the model name (`moonshot-v1-8k`,
    // `moonshot-v1-32k`, `moonshot-v1-128k`, ...); honour the smaller ones
    // instead of overestimating them as 128k.
    if model_id.starts_with("moonshot-") {
        if model_id.contains("-8k") {
            return 8_192;
        }
        if model_id.contains("-32k") {
            return 32_768;
        }
        return 128_000;
    }
    if model_id.starts_with("kimi-") {
        return 128_000;
    }

    8_192
}