// tt_shared/messages.rs

//! OpenAI-compatible request/response shapes. This is the canonical wire format
//! across all providers — adapters translate to/from provider-native formats.
3
4use std::collections::HashMap;
5
6use serde::{Deserialize, Serialize};
7
8use crate::Usage;
9
10// ---------------------------------------------------------------------------
11// tt_extras cache-control types (Fix B / §2.7)
12// ---------------------------------------------------------------------------
13
14/// Cache behaviour requested by the caller via `tt_extras.cache`.
15///
16/// Absent (no `cache` key in `tt_extras`) is treated as [`CacheMode::Normal`].
17///
18/// JSON shape:
19/// ```json
20/// { "cache": { "mode": "bypass", "ttl_secs": 3600 } }
21/// ```
22#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum CacheMode {
25    /// Normal read-write caching (default when key absent).
26    #[default]
27    Normal,
28    /// Skip lookup AND insert — always hit the provider, never populate cache.
29    Bypass,
30    /// Skip lookup, but DO insert (force-refresh stale entry).
31    Refresh,
32    /// Do lookup, but never insert (read-only cache consumer).
33    #[serde(rename = "read-only")]
34    ReadOnly,
35}
36
37/// Typed cache-control extracted from `tt_extras`.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct CacheControlConfig {
40    /// Requested cache behaviour.
41    #[serde(default)]
42    pub mode: CacheMode,
43    /// Override TTL for cache inserts. `None` = use the gateway default.
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub ttl_secs: Option<u64>,
46}
47
48/// Parse [`CacheControlConfig`] from a request's `tt_extras` map.
49///
50/// Returns `None` when `tt_extras` does not contain a `"cache"` key.
51/// Returns the default config (normal mode, no TTL override) when the key is
52/// present but the value fails to deserialize — so a malformed field degrades
53/// gracefully rather than hard-failing.
54pub fn parse_cache_control(
55    extras: &HashMap<String, serde_json::Value>,
56) -> Option<CacheControlConfig> {
57    let val = extras.get("cache")?;
58    match serde_json::from_value::<CacheControlConfig>(val.clone()) {
59        Ok(cfg) => Some(cfg),
60        Err(e) => {
61            // Log at warn level so operators can see bad payloads; fall back
62            // to normal (don't block the request).
63            tracing::warn!(
64                error = %e,
65                "tt_extras.cache deserialization failed — treating as normal"
66            );
67            Some(CacheControlConfig::default())
68        }
69    }
70}
71
#[cfg(test)]
mod cache_control_tests {
    use super::*;

    /// Builds a `tt_extras` map from a JSON object literal.
    fn extras(json: &str) -> HashMap<String, serde_json::Value> {
        serde_json::from_str(json).unwrap()
    }

    /// Runs `parse_cache_control` on `json` and unwraps the resulting config;
    /// panics when the parser reports that no `cache` key was present.
    fn parsed(json: &str) -> CacheControlConfig {
        let map = extras(json);
        parse_cache_control(&map).unwrap()
    }

    #[test]
    fn no_cache_key_returns_none() {
        let map = extras("{}");
        assert!(parse_cache_control(&map).is_none());
    }

    #[test]
    fn bypass_mode_parsed() {
        let CacheControlConfig { mode, ttl_secs } = parsed(r#"{"cache":{"mode":"bypass"}}"#);
        assert_eq!(mode, CacheMode::Bypass);
        assert!(ttl_secs.is_none());
    }

    #[test]
    fn refresh_mode_with_ttl() {
        let CacheControlConfig { mode, ttl_secs } =
            parsed(r#"{"cache":{"mode":"refresh","ttl_secs":3600}}"#);
        assert_eq!(mode, CacheMode::Refresh);
        assert_eq!(ttl_secs, Some(3600));
    }

    #[test]
    fn read_only_mode() {
        let cfg = parsed(r#"{"cache":{"mode":"read-only"}}"#);
        assert_eq!(cfg.mode, CacheMode::ReadOnly);
    }

    #[test]
    fn absent_mode_defaults_to_normal() {
        let cfg = parsed(r#"{"cache":{}}"#);
        assert_eq!(cfg.mode, CacheMode::Normal);
    }

    #[test]
    fn malformed_value_falls_back_to_default() {
        // A non-object `cache` value cannot deserialize; the parser must
        // still hand back a config rather than `None`.
        let cfg = parsed(r#"{"cache":"not-an-object"}"#);
        assert_eq!(cfg.mode, CacheMode::Normal);
    }
}
118
119#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120pub struct ChatCompletionRequest {
121    pub model: String,
122    pub messages: Vec<Message>,
123
124    #[serde(default, skip_serializing_if = "Option::is_none")]
125    pub temperature: Option<f32>,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub top_p: Option<f32>,
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub max_tokens: Option<u32>,
130    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
131    pub stream: bool,
132    #[serde(default, skip_serializing_if = "Vec::is_empty")]
133    pub tools: Vec<Tool>,
134    #[serde(default, skip_serializing_if = "Option::is_none")]
135    pub tool_choice: Option<ToolChoice>,
136    #[serde(default, skip_serializing_if = "Option::is_none")]
137    pub response_format: Option<ResponseFormat>,
138    #[serde(default, skip_serializing_if = "Vec::is_empty")]
139    pub stop: Vec<String>,
140    #[serde(default, skip_serializing_if = "Option::is_none")]
141    pub presence_penalty: Option<f32>,
142    #[serde(default, skip_serializing_if = "Option::is_none")]
143    pub frequency_penalty: Option<f32>,
144    #[serde(default, skip_serializing_if = "Option::is_none")]
145    pub n: Option<u32>,
146    #[serde(default, skip_serializing_if = "Option::is_none")]
147    pub seed: Option<i64>,
148    #[serde(default, skip_serializing_if = "Option::is_none")]
149    pub user: Option<String>,
150
151    /// TokenTrimmer-internal extras (cache config, route hints, etc.) that are
152    /// stripped before forwarding to the provider.
153    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
154    pub tt_extras: HashMap<String, serde_json::Value>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158#[serde(tag = "role", rename_all = "lowercase")]
159pub enum Message {
160    System {
161        content: MessageContent,
162    },
163    User {
164        content: MessageContent,
165        #[serde(default, skip_serializing_if = "Option::is_none")]
166        name: Option<String>,
167    },
168    Assistant {
169        #[serde(default, skip_serializing_if = "Option::is_none")]
170        content: Option<MessageContent>,
171        #[serde(default, skip_serializing_if = "Vec::is_empty")]
172        tool_calls: Vec<ToolCall>,
173        #[serde(default, skip_serializing_if = "Option::is_none")]
174        name: Option<String>,
175    },
176    Tool {
177        content: MessageContent,
178        tool_call_id: String,
179    },
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
183#[serde(untagged)]
184pub enum MessageContent {
185    Text(String),
186    Parts(Vec<ContentPart>),
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
190#[serde(tag = "type", rename_all = "snake_case")]
191pub enum ContentPart {
192    Text { text: String },
193    ImageUrl { image_url: ImageUrl },
194    InputAudio { input_audio: InputAudio },
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct ImageUrl {
199    pub url: String,
200    #[serde(default, skip_serializing_if = "Option::is_none")]
201    pub detail: Option<String>,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct InputAudio {
206    pub data: String,
207    pub format: String,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct Tool {
212    #[serde(rename = "type")]
213    pub r#type: String,
214    pub function: ToolFunction,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ToolFunction {
219    pub name: String,
220    #[serde(default, skip_serializing_if = "Option::is_none")]
221    pub description: Option<String>,
222    pub parameters: serde_json::Value,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
226#[serde(untagged)]
227pub enum ToolChoice {
228    Auto(String),
229    Specific {
230        #[serde(rename = "type")]
231        r#type: String,
232        function: ToolChoiceFunction,
233    },
234}
235
236impl ToolChoice {
237    /// Let the model decide whether to call a tool (`"auto"`).
238    #[must_use]
239    pub fn auto() -> Self {
240        ToolChoice::Auto("auto".to_string())
241    }
242
243    /// Forbid tool calls — force a plain text answer (`"none"`).
244    #[must_use]
245    pub fn none() -> Self {
246        ToolChoice::Auto("none".to_string())
247    }
248
249    /// Require the model to call some tool (`"required"`).
250    #[must_use]
251    pub fn required() -> Self {
252        ToolChoice::Auto("required".to_string())
253    }
254
255    /// Require the model to call a specific named function.
256    #[must_use]
257    pub fn function(name: impl Into<String>) -> Self {
258        ToolChoice::Specific {
259            r#type: "function".to_string(),
260            function: ToolChoiceFunction { name: name.into() },
261        }
262    }
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct ToolChoiceFunction {
267    pub name: String,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct ToolCall {
272    pub id: String,
273    #[serde(rename = "type")]
274    pub r#type: String,
275    pub function: ToolCallFunction,
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct ToolCallFunction {
280    pub name: String,
281    /// Stringified JSON arguments — OpenAI convention.
282    pub arguments: String,
283}
284
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct ResponseFormat {
287    #[serde(rename = "type")]
288    pub r#type: String,
289    #[serde(default, skip_serializing_if = "Option::is_none")]
290    pub json_schema: Option<serde_json::Value>,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct ChatCompletionResponse {
295    pub id: String,
296    pub object: String,
297    pub created: i64,
298    pub model: String,
299    pub choices: Vec<Choice>,
300    pub usage: Usage,
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct Choice {
305    pub index: u32,
306    pub message: Message,
307    pub finish_reason: Option<String>,
308}
309
310/// One SSE event from a streaming chat completion.
311#[derive(Debug, Clone, Serialize, Deserialize)]
312pub struct ChatCompletionChunk {
313    pub id: String,
314    pub object: String,
315    pub created: i64,
316    pub model: String,
317    pub choices: Vec<ChunkChoice>,
318    #[serde(default, skip_serializing_if = "Option::is_none")]
319    pub usage: Option<Usage>,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct ChunkChoice {
324    pub index: u32,
325    pub delta: ChunkDelta,
326    pub finish_reason: Option<String>,
327}
328
329#[derive(Debug, Clone, Default, Serialize, Deserialize)]
330pub struct ChunkDelta {
331    #[serde(default, skip_serializing_if = "Option::is_none")]
332    pub role: Option<String>,
333    #[serde(default, skip_serializing_if = "Option::is_none")]
334    pub content: Option<String>,
335    #[serde(default, skip_serializing_if = "Vec::is_empty")]
336    pub tool_calls: Vec<ToolCall>,
337}
338
339#[derive(Debug, Clone, Serialize, Deserialize)]
340pub struct EmbeddingsRequest {
341    pub model: String,
342    pub input: EmbeddingInput,
343    #[serde(default, skip_serializing_if = "Option::is_none")]
344    pub dimensions: Option<u32>,
345    #[serde(default, skip_serializing_if = "Option::is_none")]
346    pub encoding_format: Option<String>,
347}
348
349#[derive(Debug, Clone, Serialize, Deserialize)]
350#[serde(untagged)]
351pub enum EmbeddingInput {
352    Single(String),
353    Batch(Vec<String>),
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize)]
357pub struct EmbeddingsResponse {
358    pub object: String,
359    pub data: Vec<EmbeddingData>,
360    pub model: String,
361    pub usage: Usage,
362}
363
364#[derive(Debug, Clone, Serialize, Deserialize)]
365pub struct EmbeddingData {
366    pub object: String,
367    pub index: u32,
368    pub embedding: Vec<f32>,
369}
370
/// Parse a base64 `data:` URL into `(media_type, base64_payload)`.
///
/// Provider adapters use this to forward inline image bytes as the provider's
/// native base64 image part instead of mistakenly sending the whole `data:`
/// URI as a *remote* URL reference (which the upstream rejects).
///
/// The `data:` scheme and the `;base64` marker are matched ASCII
/// case-insensitively, and the returned media type is lower-cased, so
/// `DATA:image/PNG;BASE64,AAAA` yields `("image/png", "AAAA")`. Any
/// media-type parameters (e.g. `;charset=utf-8`) are dropped. The payload is
/// returned verbatim; it is not decoded or validated.
///
/// # Returns
///
/// `None` for non-`data:` URLs, data URLs that are not base64-encoded, an
/// empty media type, or an empty payload.
#[must_use]
pub fn parse_data_url(url: &str) -> Option<(String, String)> {
    const SCHEME: &str = "data:";
    const BASE64_MARKER: &str = ";base64";

    // `get` instead of slicing: input that is too short, or that has a
    // multi-byte character straddling the boundary, yields `None` rather
    // than a panic.
    let scheme = url.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    let (meta, payload) = url[SCHEME.len()..].split_once(',')?;

    // Only base64 payloads are supported (the canonical image transport), so
    // the metadata must end with the `;base64` marker.
    let marker_start = meta.len().checked_sub(BASE64_MARKER.len())?;
    if !meta.get(marker_start..)?.eq_ignore_ascii_case(BASE64_MARKER) {
        return None;
    }

    // Drop any RFC-2397 media-type parameters — providers expect a bare MIME
    // type like `image/png` in the base64 image part.
    let media_type = meta[..marker_start].split(';').next().unwrap_or("");
    if media_type.is_empty() || payload.is_empty() {
        return None;
    }

    // MIME types are case-insensitive; hand back the lowercase form so
    // downstream comparisons against e.g. `image/png` behave predictably.
    Some((media_type.to_ascii_lowercase(), payload.to_string()))
}
391
#[cfg(test)]
mod embeddings_default_tests {
    use super::*;

    #[test]
    fn chat_request_default_is_empty() {
        let ChatCompletionRequest {
            model,
            messages,
            stream,
            tools,
            max_tokens,
            ..
        } = ChatCompletionRequest::default();

        assert!(model.is_empty());
        assert!(messages.is_empty());
        assert!(!stream);
        assert!(tools.is_empty());
        assert_eq!(max_tokens, None);
    }

    #[test]
    fn parse_data_url_extracts_media_type_and_payload() {
        let png = Some(("image/png".to_string(), "iVBORw0KGgo=".to_string()));

        assert_eq!(parse_data_url("data:image/png;base64,iVBORw0KGgo="), png);

        // Non-data URLs and non-base64 / malformed data URLs return None.
        let rejected = [
            "https://example.com/cat.png",
            "data:image/png,notbase64",
            "data:;base64,abc",
            "data:image/png;base64,",
        ];
        for url in rejected {
            assert_eq!(parse_data_url(url), None, "{}", url);
        }

        // Media-type parameters are stripped to a bare MIME type.
        assert_eq!(
            parse_data_url("data:image/png;charset=utf-8;base64,iVBORw0KGgo="),
            png
        );
    }

    #[test]
    fn tool_choice_constructors_serialize_to_the_wire_form() {
        // The keyword constructors stay an untagged bare string …
        let keywords = [
            (ToolChoice::auto(), "auto"),
            (ToolChoice::none(), "none"),
            (ToolChoice::required(), "required"),
        ];
        for (choice, wire) in keywords {
            assert_eq!(
                serde_json::to_value(choice).unwrap(),
                serde_json::json!(wire)
            );
        }

        // … and `function(name)` produces the object form.
        let specific = serde_json::to_value(ToolChoice::function("get_weather")).unwrap();
        assert_eq!(
            specific,
            serde_json::json!({ "type": "function", "function": { "name": "get_weather" } })
        );
    }
}