1pub use crabllm_core::{
9 ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, CompletionTokensDetails,
10 FinishReason, FunctionCall, FunctionDef, Message, Role, Tool, ToolCall, ToolCallDelta,
11 ToolChoice, ToolType, Usage,
12};
13
14use anyhow::Result;
15use async_stream::try_stream;
16use crabllm_core::{ApiError, Provider};
17use futures_core::Stream;
18use futures_util::StreamExt;
19use serde::{Deserialize, Serialize};
20use std::{collections::BTreeMap, sync::Arc};
21
22#[derive(Debug, Clone, Deserialize, Serialize)]
31pub struct HistoryEntry {
32 #[serde(default, skip_serializing_if = "String::is_empty")]
36 pub agent: String,
37
38 #[serde(skip)]
40 pub sender: String,
41
42 #[serde(skip)]
46 pub auto_injected: bool,
47
48 pub message: Message,
50}
51
52impl HistoryEntry {
53 pub fn system(content: impl Into<String>) -> Self {
55 Self::from_message(Message::system(content))
56 }
57
58 pub fn user(content: impl Into<String>) -> Self {
60 Self::from_message(Message::user(content))
61 }
62
63 pub fn user_with_sender(content: impl Into<String>, sender: impl Into<String>) -> Self {
65 let mut entry = Self::user(content);
66 entry.sender = sender.into();
67 entry
68 }
69
70 pub fn assistant(
77 content: impl Into<String>,
78 reasoning: Option<String>,
79 tool_calls: Option<&[ToolCall]>,
80 ) -> Self {
81 let content: String = content.into();
82 let has_tool_calls = tool_calls.is_some_and(|tcs| !tcs.is_empty());
83 let message_content = if content.is_empty() && has_tool_calls {
84 Some(serde_json::Value::Null)
85 } else {
86 Some(serde_json::Value::String(content))
87 };
88 Self::from_message(Message {
89 role: Role::Assistant,
90 content: message_content,
91 tool_calls: tool_calls.map(|tcs| tcs.to_vec()),
92 tool_call_id: None,
93 name: None,
94 reasoning_content: reasoning.filter(|s| !s.is_empty()),
95 extra: Default::default(),
96 })
97 }
98
99 pub fn tool(
101 content: impl Into<String>,
102 call_id: impl Into<String>,
103 name: impl Into<String>,
104 ) -> Self {
105 Self::from_message(Message::tool(call_id, name, content))
106 }
107
108 pub fn from_message(message: Message) -> Self {
110 Self {
111 agent: String::new(),
112 sender: String::new(),
113 auto_injected: false,
114 message,
115 }
116 }
117
118 pub fn auto_injected(mut self) -> Self {
120 self.auto_injected = true;
121 self
122 }
123
124 pub fn role(&self) -> &Role {
126 &self.message.role
127 }
128
129 pub fn text(&self) -> &str {
131 self.message.content_str().unwrap_or("")
132 }
133
134 pub fn reasoning(&self) -> &str {
136 self.message.reasoning_content.as_deref().unwrap_or("")
137 }
138
139 pub fn tool_calls(&self) -> &[ToolCall] {
141 self.message.tool_calls.as_deref().unwrap_or(&[])
142 }
143
144 pub fn tool_call_id(&self) -> &str {
146 self.message.tool_call_id.as_deref().unwrap_or("")
147 }
148
149 pub fn estimate_tokens(&self) -> usize {
151 let chars = self.text().len()
152 + self.reasoning().len()
153 + self.tool_call_id().len()
154 + self
155 .tool_calls()
156 .iter()
157 .map(|tc| tc.function.name.len() + tc.function.arguments.len())
158 .sum::<usize>();
159 (chars / 4).max(1)
160 }
161
162 pub fn to_wire_message(&self) -> Message {
168 if self.message.role != Role::Assistant || self.agent.is_empty() {
169 return self.message.clone();
170 }
171 let tagged = format!("<from agent=\"{}\">\n{}\n</from>", self.agent, self.text());
172 Message {
173 role: Role::Assistant,
174 content: Some(serde_json::Value::String(tagged)),
175 tool_calls: self.message.tool_calls.clone(),
176 tool_call_id: self.message.tool_call_id.clone(),
177 name: self.message.name.clone(),
178 reasoning_content: self.message.reasoning_content.clone(),
179 extra: self.message.extra.clone(),
180 }
181 }
182}
183
184pub fn estimate_history_tokens(entries: &[HistoryEntry]) -> usize {
186 entries.iter().map(|e| e.estimate_tokens()).sum()
187}
188
189fn empty_tool_call() -> ToolCall {
192 ToolCall {
193 index: None,
194 id: String::new(),
195 kind: ToolType::Function,
196 function: FunctionCall::default(),
197 }
198}
199
200pub struct MessageBuilder {
202 role: Role,
203 content: String,
204 reasoning: String,
205 calls: BTreeMap<u32, ToolCall>,
206}
207
208impl MessageBuilder {
209 pub fn new(role: Role) -> Self {
211 Self {
212 role,
213 content: String::new(),
214 reasoning: String::new(),
215 calls: BTreeMap::new(),
216 }
217 }
218
219 pub fn accept(&mut self, chunk: &ChatCompletionChunk) -> bool {
223 let Some(choice) = chunk.choices.first() else {
224 return false;
225 };
226 let delta = &choice.delta;
227
228 let mut has_content = false;
229 if let Some(text) = delta.content.as_deref()
230 && !text.is_empty()
231 {
232 self.content.push_str(text);
233 has_content = true;
234 }
235 if let Some(reason) = delta.reasoning_content.as_deref()
236 && !reason.is_empty()
237 {
238 self.reasoning.push_str(reason);
239 }
240 if let Some(calls) = delta.tool_calls.as_deref() {
241 for call in calls {
242 self.merge_tool_call(call);
243 }
244 }
245 has_content
246 }
247
248 fn merge_tool_call(&mut self, delta: &ToolCallDelta) {
249 let entry = self
250 .calls
251 .entry(delta.index)
252 .or_insert_with(empty_tool_call);
253 entry.index = Some(delta.index);
254 if let Some(id) = &delta.id
255 && !id.is_empty()
256 {
257 entry.id = id.clone();
258 }
259 if let Some(kind) = delta.kind {
260 entry.kind = kind;
261 }
262 if let Some(function) = &delta.function {
263 if let Some(name) = &function.name
264 && !name.is_empty()
265 {
266 entry.function.name = name.clone();
267 }
268 if let Some(args) = &function.arguments {
269 entry.function.arguments.push_str(args);
270 }
271 }
272 }
273
274 pub fn peek_tool_calls(&self) -> Vec<ToolCall> {
276 self.calls
277 .values()
278 .filter(|c| !c.function.name.is_empty())
279 .cloned()
280 .collect()
281 }
282
283 pub fn build(self) -> Message {
285 let tool_calls: Vec<ToolCall> = self
286 .calls
287 .into_values()
288 .filter(|c| !c.id.is_empty() && !c.function.name.is_empty())
289 .collect();
290 let has_tool_calls = !tool_calls.is_empty();
291 let content = if self.content.is_empty() && has_tool_calls && self.role == Role::Assistant {
292 Some(serde_json::Value::Null)
293 } else {
294 Some(serde_json::Value::String(self.content))
295 };
296 let reasoning_content = if self.reasoning.is_empty() {
297 None
298 } else {
299 Some(self.reasoning)
300 };
301 Message {
302 role: self.role,
303 content,
304 tool_calls: if has_tool_calls {
305 Some(tool_calls)
306 } else {
307 None
308 },
309 tool_call_id: None,
310 name: None,
311 reasoning_content,
312 extra: Default::default(),
313 }
314 }
315}
316
317pub struct Model<P: Provider + 'static> {
324 inner: Arc<P>,
325}
326
327impl<P: Provider + 'static> Model<P> {
328 pub fn new(provider: P) -> Self {
330 Self {
331 inner: Arc::new(provider),
332 }
333 }
334
335 pub fn from_arc(provider: Arc<P>) -> Self {
337 Self { inner: provider }
338 }
339
340 pub async fn send_ct(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse> {
342 let mut req = request;
343 req.stream = Some(false);
344 let model_label = req.model.clone();
345 self.inner
346 .chat_completion(&req)
347 .await
348 .map_err(|e| format_provider_error(&model_label, "send", e))
349 }
350
351 pub fn stream_ct(
353 &self,
354 request: ChatCompletionRequest,
355 ) -> impl Stream<Item = Result<ChatCompletionChunk>> + Send + 'static {
356 let inner = Arc::clone(&self.inner);
357 let mut req = request;
358 req.stream = Some(true);
359 let model_label = req.model.clone();
360 try_stream! {
361 let mut stream = inner
362 .chat_completion_stream(&req)
363 .await
364 .map_err(|e| format_provider_error(&model_label, "stream open", e))?;
365 while let Some(chunk) = stream.next().await {
366 yield chunk
367 .map_err(|e| format_provider_error(&model_label, "stream chunk", e))?;
368 }
369 }
370 }
371}
372
373impl<P: Provider + 'static> Clone for Model<P> {
374 fn clone(&self) -> Self {
375 Self {
376 inner: Arc::clone(&self.inner),
377 }
378 }
379}
380
381impl<P: Provider + 'static> std::fmt::Debug for Model<P> {
382 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 f.debug_struct("Model").finish()
384 }
385}
386
387fn format_provider_error(model: &str, op: &str, e: crabllm_core::Error) -> anyhow::Error {
388 match e {
389 crabllm_core::Error::Provider { status, body } => {
390 let msg = serde_json::from_str::<ApiError>(&body)
391 .map(|api_err| api_err.error.message)
392 .unwrap_or_else(|_| truncate(&body, 200));
393 anyhow::anyhow!("model {op} failed for '{model}' (HTTP {status}): {msg}")
394 }
395 other => anyhow::anyhow!("model {op} failed for '{model}': {other}"),
396 }
397}
398
/// Keeps at most the first `max` characters of `s`.
///
/// When `s` is longer than `max` characters the kept prefix is followed by
/// `"..."`; otherwise `s` is returned unchanged. The cut is located with
/// `char_indices`, so it always falls on a UTF-8 character boundary.
fn truncate(s: &str, max: usize) -> String {
    if let Some((cut, _)) = s.char_indices().nth(max) {
        let mut shortened = String::with_capacity(cut + 3);
        shortened.push_str(&s[..cut]);
        shortened.push_str("...");
        shortened
    } else {
        s.to_owned()
    }
}
405
/// Returns a default context-window size (in tokens) for `model_id`.
///
/// The id is matched by prefix against a table of known model families.
/// The first matching row wins, so more specific prefixes are listed before
/// the general family they belong to (for example `gpt-4.1` and `gpt-4-32k`
/// before `gpt-4`, and `o1-mini` before `o1`). Ids that match no row get a
/// fallback of 8,192 tokens.
pub fn default_context_limit(model_id: &str) -> usize {
    // Limit used when the model family is not recognised.
    const FALLBACK: usize = 8_192;

    // Ordered (prefix, limit) pairs; order matters because matching stops at
    // the first hit.
    const LIMITS: &[(&str, usize)] = &[
        ("claude-", 200_000),
        ("gpt-4o", 128_000),
        ("gpt-4-turbo", 128_000),
        // Must precede "gpt-4": these variants have windows that differ from
        // the base model and would otherwise be reported as 8,192.
        ("gpt-4.1", 1_047_576),
        ("gpt-4-32k", 32_768),
        ("gpt-4", 8_192),
        ("gpt-3.5", 16_385),
        // Must precede "o1": the mini/preview variants have a smaller window
        // than the full model, and over-estimating risks oversized requests.
        ("o1-mini", 128_000),
        ("o1-preview", 128_000),
        ("o1", 200_000),
        ("o3", 200_000),
        ("o4", 200_000),
        ("grok-", 131_072),
        ("qwen-", 32_768),
        ("qwq-", 32_768),
        ("kimi-", 128_000),
        ("moonshot-", 128_000),
    ];

    LIMITS
        .iter()
        .find(|(prefix, _)| model_id.starts_with(*prefix))
        .map_or(FALLBACK, |&(_, limit)| limit)
}