1use std::{
2 error::Error,
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use hanzo_ml::Tensor;
8#[cfg(feature = "pyo3_macros")]
9use pyo3::{pyclass, pymethods};
10use serde::Serialize;
11
12use crate::{sampler::TopLogprob, tools::ToolCallResponse};
13
/// Fingerprint string identifying this (local) backend.
///
/// The response types in this module carry a `system_fingerprint: String` field; this constant
/// is presumably what callers place there — confirm at the construction sites.
pub const SYSTEM_FINGERPRINT: &str = "local";
15
/// Emits a Python `__repr__` for a `#[pyclass]` type.
///
/// The generated `#[pymethods]` block is compiled only with the `pyo3_macros` feature, so an
/// invocation expands to nothing otherwise. The target type must implement `Debug`, because
/// the representation is its alternate (`{:#?}`, multi-line) `Debug` rendering.
macro_rules! generate_repr {
    ($target:ident) => {
        #[cfg(feature = "pyo3_macros")]
        #[pymethods]
        impl $target {
            // `repr(obj)` in Python shows the pretty-printed Rust `Debug` output.
            fn __repr__(&self) -> String {
                format!("{self:#?}")
            }
        }
    };
}
27
/// Message payload of a chat-completion [`Choice`] (its `message` field).
///
/// With the `pyo3_macros` feature this is a Python class: `pyo3(get_all)` exposes every field
/// as a getter and `generate_repr!` adds a `__repr__` showing the pretty `Debug` form.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ResponseMessage {
    // Field order is the serialized key order and the `Debug`/`__repr__` order.
    /// Message text; serialized as `null` when `None`.
    pub content: Option<String>,
    /// Role string of the message.
    pub role: String,
    /// Tool calls attached to the message; serialized as `null` when `None`.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
    /// Optional reasoning text; unlike the fields above, the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}

generate_repr!(ResponseMessage);
43
/// Incremental message payload carried by a [`ChunkChoice`] (its `delta` field).
///
/// Has the same fields and serialization rules as [`ResponseMessage`]. With the `pyo3_macros`
/// feature it is a Python class with getters for every field and a generated `__repr__`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Delta {
    // Field order is the serialized key order and the `Debug`/`__repr__` order.
    /// Text in this delta; serialized as `null` when `None`.
    pub content: Option<String>,
    /// Role string of the message.
    pub role: String,
    /// Tool calls in this delta; serialized as `null` when `None`.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
    /// Optional reasoning text; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}

generate_repr!(Delta);
59
/// Log-probability entry for one token.
///
/// Used inside [`Logprobs`] and directly by [`ChunkChoice`] and [`CompletionChunkChoice`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ResponseLogprob {
    /// The token text.
    pub token: String,
    /// Log-probability of `token`.
    pub logprob: f32,
    /// Token bytes, if available; serialized as `null` when `None`.
    // NOTE(review): presumably the token's raw UTF-8 bytes — confirm with the producer.
    pub bytes: Option<Vec<u8>>,
    /// Alternative candidates, as [`TopLogprob`] entries from the sampler module.
    pub top_logprobs: Vec<TopLogprob>,
}

generate_repr!(ResponseLogprob);
72
/// Container for per-token [`ResponseLogprob`] entries.
///
/// Used by the non-chunk choices: [`Choice`] and [`CompletionChoice`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Logprobs {
    /// Per-token entries; serialized as `null` when `None`.
    pub content: Option<Vec<ResponseLogprob>>,
}

generate_repr!(Logprobs);
82
/// One entry in [`ChatCompletionResponse::choices`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Choice {
    /// Reason generation stopped for this choice (always present here, unlike [`ChunkChoice`]).
    pub finish_reason: String,
    /// Index of this choice.
    pub index: usize,
    /// The generated message.
    pub message: ResponseMessage,
    /// Per-token log-probabilities; serialized as `null` when `None`.
    pub logprobs: Option<Logprobs>,
}

generate_repr!(Choice);
95
/// One entry in [`ChatCompletionChunkResponse::choices`].
///
/// Unlike [`Choice`], `finish_reason` is optional. `logprobs` holds a single
/// [`ResponseLogprob`] rather than a [`Logprobs`] container.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChunkChoice {
    /// Reason generation stopped, if it has; serialized as `null` when `None`.
    pub finish_reason: Option<String>,
    /// Index of this choice.
    pub index: usize,
    /// Incremental message content for this chunk.
    pub delta: Delta,
    /// Log-probability entry for this chunk; serialized as `null` when `None`.
    pub logprobs: Option<ResponseLogprob>,
}

generate_repr!(ChunkChoice);
108
/// One entry in [`CompletionChunkResponse::choices`].
///
/// Like [`ChunkChoice`], `logprobs` is a single [`ResponseLogprob`] and `finish_reason` is
/// optional.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChunkChoice {
    /// Text in this chunk.
    pub text: String,
    /// Index of this choice.
    pub index: usize,
    /// Log-probability entry for this chunk; serialized as `null` when `None`.
    pub logprobs: Option<ResponseLogprob>,
    /// Reason generation stopped, if it has; serialized as `null` when `None`.
    pub finish_reason: Option<String>,
}

generate_repr!(CompletionChunkChoice);
121
/// Token counts and throughput/timing statistics.
///
/// Required on [`ChatCompletionResponse`] and [`CompletionResponse`]; optional on
/// [`ChatCompletionChunkResponse`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct Usage {
    /// Number of generated tokens.
    pub completion_tokens: usize,
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Total token count.
    // NOTE(review): presumably `prompt_tokens + completion_tokens`; nothing here enforces it.
    pub total_tokens: usize,
    // The remaining fields are throughput (tokens per second) and durations, presumably in
    // seconds as the `_sec` suffixes suggest — confirm at the producer.
    pub avg_tok_per_sec: f32,
    pub avg_prompt_tok_per_sec: f32,
    pub avg_compl_tok_per_sec: f32,
    pub total_time_sec: f32,
    pub total_prompt_time_sec: f32,
    pub total_completion_time_sec: f32,
}

generate_repr!(Usage);
139
/// Record of one tool call, as listed in [`ChatCompletionResponse::agentic_tool_calls`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct AgenticToolCallRecord {
    /// Round in which the call was made.
    pub round: usize,
    /// Name of the called tool.
    pub name: String,
    /// Call arguments as a string.
    // NOTE(review): presumably a JSON-encoded object — confirm with the producer.
    pub arguments: String,
    /// Textual result of the call.
    pub result_content: String,
    /// Result images as base64 strings; the key is omitted when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub result_images_base64: Vec<String>,
    /// IDs of files associated with the call; the key is omitted when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub file_ids: Vec<String>,
}

generate_repr!(AgenticToolCallRecord);
157
/// Full (non-chunk) chat-completion response object.
///
/// Carried by [`Response::Done`]. [`Response::ModelError`] also carries it, paired with an
/// error message, as the incomplete response.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub choices: Vec<Choice>,
    // NOTE(review): `created` is `u64` here but `u128` on `ChatCompletionChunkResponse`;
    // presumably a Unix timestamp in both — confirm units at the producer. Changing either
    // type would break callers, so the mismatch is only flagged.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Usage,
    /// Tool calls made while producing this response; the key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agentic_tool_calls: Option<Vec<AgenticToolCallRecord>>,
    /// Files attached to the response; the key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub files: Option<Vec<crate::files::File>>,
    /// Session identifier; the key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}

generate_repr!(ChatCompletionResponse);
182
/// Chat-completion chunk object, carried by [`Response::Chunk`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionChunkResponse {
    pub id: String,
    pub choices: Vec<ChunkChoice>,
    // NOTE(review): `u128` here vs `u64` on `ChatCompletionResponse::created`.
    pub created: u128,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    /// Usage statistics, if attached to this chunk; no `skip_serializing_if`, so it
    /// serializes as `null` when `None`.
    pub usage: Option<Usage>,
    /// Session identifier; the key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}

generate_repr!(ChatCompletionChunkResponse);
201
/// One entry in [`CompletionResponse::choices`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChoice {
    /// Reason generation stopped for this choice.
    pub finish_reason: String,
    /// Index of this choice.
    pub index: usize,
    /// Generated text.
    pub text: String,
    /// Per-token log-probabilities; serialized as `null` when `None`.
    pub logprobs: Option<Logprobs>,
}

generate_repr!(CompletionChoice);
214
/// Full (non-chunk) text-completion response object.
///
/// Carried by [`Response::CompletionDone`]. [`Response::CompletionModelError`] also carries
/// it, paired with an error message, as the incomplete response.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub choices: Vec<CompletionChoice>,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
    pub usage: Usage,
}

generate_repr!(CompletionResponse);
230
/// Text-completion chunk object, carried by [`Response::CompletionChunk`].
///
/// Unlike [`ChatCompletionChunkResponse`], it has no `usage` or `session_id` field.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChunkResponse {
    pub id: String,
    pub choices: Vec<CompletionChunkChoice>,
    pub created: u128,
    pub model: String,
    pub system_fingerprint: String,
    pub object: String,
}

generate_repr!(CompletionChunkResponse);
245
/// One generated image in an [`ImageGenerationResponse`].
///
/// Both fields are optional and serialize as `null` when `None`.
// NOTE(review): presumably exactly one of `url` / `b64_json` is set, depending on the
// requested response format — confirm with the producer.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ImageChoice {
    /// Location of the image, if returned by reference.
    pub url: Option<String>,
    /// Base64-encoded image data, if returned inline.
    pub b64_json: Option<String>,
}

generate_repr!(ImageChoice);
255
/// Image-generation result, carried by [`Response::ImageGeneration`].
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
pub struct ImageGenerationResponse {
    /// Creation timestamp (units not visible here — TODO confirm).
    pub created: u128,
    /// Generated images.
    pub data: Vec<ImageChoice>,
}

generate_repr!(ImageGenerationResponse);
265
/// Tool-specific details attached to an [`AgenticToolCallPhase`].
///
/// Unlike the response structs in this module, this type is neither `Serialize` nor a
/// Python class; it only derives `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub enum AgenticToolCallData {
    /// Details of a code-execution tool call. Every field is optional or possibly empty.
    CodeExecution {
        code: Option<String>,
        stdout: Option<String>,
        stderr: Option<String>,
        exception: Option<String>,
        images: Vec<image::DynamicImage>,
        video_frames: Vec<image::DynamicImage>,
        video_frame_count: Option<usize>,
        working_directory: Option<String>,
        /// Execution duration, in milliseconds per the `_ms` suffix.
        execution_time_ms: Option<u64>,
    },
    /// Details of a web-search tool call.
    WebSearch {
        query: Option<String>,
        results_count: Option<usize>,
        sources: Vec<String>,
    },
    /// Any other tool: its arguments and result content as plain strings.
    Custom { arguments: String, content: String },
}
290
/// Phase of a tool call, reported through [`Response::AgenticToolCallProgress`].
///
/// Both phases carry an [`AgenticToolCallData`] payload.
#[derive(Debug, Clone)]
pub enum AgenticToolCallPhase {
    /// The tool is being called (presumably emitted before it runs — confirm with emitter).
    Calling(AgenticToolCallData),
    /// The tool call has completed.
    Complete(AgenticToolCallData),
}
299
300pub enum Response {
305 InternalError(Box<dyn Error + Send + Sync>),
306 ValidationError(Box<dyn Error + Send + Sync>),
307 ModelError(String, ChatCompletionResponse),
309 Done(ChatCompletionResponse),
310 Chunk(ChatCompletionChunkResponse),
311 CompletionModelError(String, CompletionResponse),
313 CompletionDone(CompletionResponse),
314 CompletionChunk(CompletionChunkResponse),
315 ImageGeneration(ImageGenerationResponse),
317 Speech {
319 pcm: Arc<Vec<f32>>,
320 rate: usize,
321 channels: usize,
322 },
323 Raw {
325 logits_chunks: Vec<Tensor>,
326 tokens: Vec<u32>,
327 },
328 Embeddings {
329 embeddings: Vec<f32>,
330 prompt_tokens: usize,
331 total_tokens: usize,
332 },
333 AgenticToolCallProgress {
335 round: usize,
336 tool_name: String,
337 phase: AgenticToolCallPhase,
338 },
339 AgenticToolApprovalRequired {
340 approval_id: String,
341 session_id: String,
342 round: usize,
343 tool: crate::AgentToolMetadata,
344 arguments: serde_json::Value,
345 },
346 File(crate::files::File),
348}
349
/// Success half of a [`Response`], produced by [`Response::as_result`].
///
/// Each variant mirrors the identically named [`Response`] variant and carries the same
/// payload unchanged.
#[derive(Debug, Clone)]
pub enum ResponseOk {
    /// Finished chat completion.
    Done(ChatCompletionResponse),
    /// One chat-completion chunk.
    Chunk(ChatCompletionChunkResponse),
    /// Finished text completion.
    CompletionDone(CompletionResponse),
    /// One text-completion chunk.
    CompletionChunk(CompletionChunkResponse),
    /// Generated images.
    ImageGeneration(ImageGenerationResponse),
    /// Generated audio samples.
    Speech {
        pcm: Arc<Vec<f32>>,
        rate: usize,
        channels: usize,
    },
    /// Raw model output: logits tensors and token ids.
    Raw {
        logits_chunks: Vec<Tensor>,
        tokens: Vec<u32>,
    },
    /// Embedding values plus token counts.
    Embeddings {
        embeddings: Vec<f32>,
        prompt_tokens: usize,
        total_tokens: usize,
    },
    /// Progress update for a tool call in a given round.
    AgenticToolCallProgress {
        round: usize,
        tool_name: String,
        phase: AgenticToolCallPhase,
    },
    /// A tool call needs approval.
    AgenticToolApprovalRequired {
        approval_id: String,
        session_id: String,
        round: usize,
        tool: crate::AgentToolMetadata,
        arguments: serde_json::Value,
    },
    /// A file produced for the request.
    File(crate::files::File),
}
392
/// Error half of a [`Response`], produced (boxed) by [`Response::as_result`].
///
/// For `InternalError` and `ValidationError`, formatting forwards to the wrapped error. The
/// model-error variants format as a struct with `msg` and `incomplete_response` fields.
pub enum ResponseErr {
    /// Wrapped internal error.
    InternalError(Box<dyn Error + Send + Sync>),
    /// Wrapped validation error.
    ValidationError(Box<dyn Error + Send + Sync>),
    /// Chat model error: the message plus the incomplete chat response.
    ModelError(String, ChatCompletionResponse),
    /// Text-completion model error: the message plus the incomplete completion response.
    CompletionModelError(String, CompletionResponse),
}
399
400impl Display for ResponseErr {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 match self {
403 Self::InternalError(e) | Self::ValidationError(e) => Display::fmt(e, f),
404 Self::ModelError(e, x) => f
405 .debug_struct("ChatModelError")
406 .field("msg", e)
407 .field("incomplete_response", x)
408 .finish(),
409 Self::CompletionModelError(e, x) => f
410 .debug_struct("CompletionModelError")
411 .field("msg", e)
412 .field("incomplete_response", x)
413 .finish(),
414 }
415 }
416}
417
418impl Debug for ResponseErr {
419 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
420 match self {
421 Self::InternalError(e) | Self::ValidationError(e) => Debug::fmt(e, f),
422 Self::ModelError(e, x) => f
423 .debug_struct("ChatModelError")
424 .field("msg", e)
425 .field("incomplete_response", x)
426 .finish(),
427 Self::CompletionModelError(e, x) => f
428 .debug_struct("CompletionModelError")
429 .field("msg", e)
430 .field("incomplete_response", x)
431 .finish(),
432 }
433 }
434}
435
436impl std::error::Error for ResponseErr {}
437
438impl Response {
439 pub fn as_result(self) -> Result<ResponseOk, Box<ResponseErr>> {
441 match self {
442 Self::Done(x) => Ok(ResponseOk::Done(x)),
443 Self::Chunk(x) => Ok(ResponseOk::Chunk(x)),
444 Self::CompletionDone(x) => Ok(ResponseOk::CompletionDone(x)),
445 Self::CompletionChunk(x) => Ok(ResponseOk::CompletionChunk(x)),
446 Self::InternalError(e) => Err(Box::new(ResponseErr::InternalError(e))),
447 Self::ValidationError(e) => Err(Box::new(ResponseErr::ValidationError(e))),
448 Self::ModelError(e, x) => Err(Box::new(ResponseErr::ModelError(e, x))),
449 Self::CompletionModelError(e, x) => {
450 Err(Box::new(ResponseErr::CompletionModelError(e, x)))
451 }
452 Self::ImageGeneration(x) => Ok(ResponseOk::ImageGeneration(x)),
453 Self::Speech {
454 pcm,
455 rate,
456 channels,
457 } => Ok(ResponseOk::Speech {
458 pcm,
459 rate,
460 channels,
461 }),
462 Self::Raw {
463 logits_chunks,
464 tokens,
465 } => Ok(ResponseOk::Raw {
466 logits_chunks,
467 tokens,
468 }),
469 Self::Embeddings {
470 embeddings,
471 prompt_tokens,
472 total_tokens,
473 } => Ok(ResponseOk::Embeddings {
474 embeddings,
475 prompt_tokens,
476 total_tokens,
477 }),
478 Self::AgenticToolCallProgress {
479 round,
480 tool_name,
481 phase,
482 } => Ok(ResponseOk::AgenticToolCallProgress {
483 round,
484 tool_name,
485 phase,
486 }),
487 Self::AgenticToolApprovalRequired {
488 approval_id,
489 session_id,
490 round,
491 tool,
492 arguments,
493 } => Ok(ResponseOk::AgenticToolApprovalRequired {
494 approval_id,
495 session_id,
496 round,
497 tool,
498 arguments,
499 }),
500 Self::File(f) => Ok(ResponseOk::File(f)),
501 }
502 }
503}