chronicle_proxy/
format.rs

//! The common format for requests. This mostly hews to the OpenAI format, except that
//! some response fields are made optional to accommodate different model providers.
3
4use std::{borrow::Cow, collections::BTreeMap};
5
6use error_stack::Report;
7use serde::{Deserialize, Serialize};
8use serde_with::{formats::PreferMany, serde_as, OneOrMany};
9use uuid::Uuid;
10
11use crate::providers::ProviderError;
12
13/// A chat response, in non-chunked format
14#[derive(Serialize, Deserialize, Debug, Clone)]
15pub struct ChatResponse<CHOICE> {
16    // Omitted certain fields that aren't really useful
17    // id: String,
18    // object: String,
19    /// Unix timestamp in seconds
20    pub created: u64,
21    /// The model that was used
22    pub model: Option<String>,
23    /// A fingerprint for the system prompt
24    pub system_fingerprint: Option<String>,
25    /// The chat choices
26    pub choices: Vec<CHOICE>,
27    /// Token usage information
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub usage: Option<UsageResponse>,
30}
31
/// A chunk of streaming response; each choice carries a [`ChatChoiceDelta`].
pub type StreamingChatResponse = ChatResponse<ChatChoiceDelta>;
/// A non-streaming chat response; each choice carries a complete [`ChatChoice`].
pub type SingleChatResponse = ChatResponse<ChatChoice>;
36
37impl ChatResponse<ChatChoice> {
38    ///Create a new, empty ChatResponse designed for collecting streaming chat responses.
39    pub fn new_for_collection(num_choices: usize) -> Self {
40        SingleChatResponse {
41            created: 0,
42            model: None,
43            system_fingerprint: None,
44            choices: Vec::with_capacity(num_choices),
45            usage: Some(UsageResponse {
46                prompt_tokens: None,
47                completion_tokens: None,
48                total_tokens: None,
49            }),
50        }
51    }
52
53    /// Merge a streaming delta into this response.
54    pub fn merge_delta(&mut self, chunk: &ChatResponse<ChatChoiceDelta>) {
55        if self.created == 0 {
56            self.created = chunk.created;
57        }
58
59        if self.model.is_none() {
60            self.model = chunk.model.clone();
61        }
62
63        if self.system_fingerprint.is_none() {
64            self.system_fingerprint = chunk.system_fingerprint.clone();
65        }
66
67        if let Some(delta_usage) = chunk.usage.as_ref() {
68            if let Some(usage) = self.usage.as_mut() {
69                usage.merge(delta_usage);
70            } else {
71                self.usage = chunk.usage.clone();
72            }
73        }
74
75        for choice in chunk.choices.iter() {
76            if choice.index >= self.choices.len() {
77                // Resize to either the index mentioned here, or the total number of choices in
78                // this message. This way we only resize once.
79                let new_size = std::cmp::max(chunk.choices.len(), choice.index + 1);
80                self.choices.resize(new_size, ChatChoice::default());
81
82                for i in 0..self.choices.len() {
83                    self.choices[i].index = i;
84                }
85            }
86
87            let c = &mut self.choices[choice.index];
88            c.message.add_delta(&choice.delta);
89
90            if let Some(finish) = choice.finish_reason.as_ref() {
91                c.finish_reason = finish.clone();
92            }
93        }
94    }
95}
96
97/// For when we need to make a non-streaming chat response appear like it was a streaming response
98impl From<SingleChatResponse> for StreamingChatResponse {
99    fn from(value: SingleChatResponse) -> Self {
100        ChatResponse {
101            created: value.created,
102            model: value.model,
103            system_fingerprint: value.system_fingerprint,
104            choices: value
105                .choices
106                .into_iter()
107                .map(|c| ChatChoiceDelta {
108                    index: c.index,
109                    delta: c.message,
110                    finish_reason: Some(c.finish_reason),
111                })
112                .collect(),
113            usage: value.usage,
114        }
115    }
116}
117
/// A single complete choice in a chat response.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ChatChoice {
    /// Which choice this is (its position in `choices`)
    pub index: usize,
    /// The message
    pub message: ChatMessage,
    /// The reason the chat terminated. Defaults to [`FinishReason::Stop`].
    pub finish_reason: FinishReason,
}
128
/// A delta in a streaming chat choice. Deltas with the same `index` are merged into one
/// [`ChatChoice`] when a stream is collected.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ChatChoiceDelta {
    /// Which choice this delta belongs to
    pub index: usize,
    /// The partial message; its content is appended to previously received content
    pub delta: ChatMessage,
    /// The reason the chat terminated, if this is the final delta in the choice
    pub finish_reason: Option<FinishReason>,
}
139
140#[derive(Serialize, Deserialize, Default, Debug, Clone)]
141#[serde(rename_all = "snake_case")]
142pub enum FinishReason {
143    #[default]
144    Stop,
145    Length,
146    ContentFilter,
147    ToolCalls,
148    #[serde(untagged)]
149    Other(Cow<'static, str>),
150}
151
152impl FinishReason {
153    pub fn as_str(&self) -> &str {
154        match self {
155            FinishReason::Stop => "stop",
156            FinishReason::Length => "length",
157            FinishReason::ContentFilter => "content_filter",
158            FinishReason::ToolCalls => "tool_calls",
159            FinishReason::Other(reason) => reason.as_ref(),
160        }
161    }
162}
163
164impl std::fmt::Display for FinishReason {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        write!(f, "{}", self.as_str())
167    }
168}
169
/// A single message in a chat. Also used as the `delta` payload of streaming chunks, which is
/// why every field is optional.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ChatMessage {
    /// The role of the message, such as "user" or "assistant".
    /// Serialized even when `None` (there is no `skip_serializing_if`).
    pub role: Option<String>,
    /// Some providers support this natively. For those that don't, the name
    /// will be prepended to the message using the format "{name}: {content}".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Text content of the message. Serialized even when `None`.
    pub content: Option<String>,
    /// Tool calls to be invoked. Defaults to empty when absent and is omitted from output when
    /// empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// A tool call ID when responding to the tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
188
189impl ChatMessage {
190    /// Merge a chat delta into this message. This replaces most fields, but will concatenate
191    /// message content.
192    pub fn add_delta(&mut self, delta: &ChatMessage) {
193        if self.role.is_none() {
194            self.role = delta.role.clone();
195        }
196        if self.name.is_none() {
197            self.name = delta.name.clone();
198        }
199
200        if self.tool_call_id.is_none() {
201            self.tool_call_id = delta.tool_call_id.clone();
202        }
203
204        match (&mut self.content, &delta.content) {
205            (Some(content), Some(new_content)) => content.push_str(new_content),
206            (None, Some(new_content)) => {
207                self.content = Some(new_content.clone());
208            }
209            _ => {}
210        }
211
212        for tool_call in &delta.tool_calls {
213            let Some(index) = tool_call.index else {
214                // Tool call chunks must always have an index
215                continue;
216            };
217            if self.tool_calls.len() <= index {
218                self.tool_calls.resize(
219                    index + 1,
220                    ToolCall {
221                        index: None,
222                        id: None,
223                        typ: None,
224                        function: ToolCallFunction {
225                            name: None,
226                            arguments: None,
227                        },
228                    },
229                );
230            }
231
232            self.tool_calls[index].merge_delta(tool_call);
233        }
234    }
235}
236
/// Counts of prompt, completion, and total tokens. Each count is optional because a provider
/// (or an individual streaming chunk) may report only some of them.
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct UsageResponse {
    /// The number of input tokens
    pub prompt_tokens: Option<usize>,
    /// The number of output tokens
    pub completion_tokens: Option<usize>,
    /// The sum of the input and output tokens
    pub total_tokens: Option<usize>,
}
247
248impl UsageResponse {
249    /// Return true if there is no usage info recorded in this response
250    pub fn is_empty(&self) -> bool {
251        self.prompt_tokens.is_none()
252            && self.completion_tokens.is_none()
253            && self.total_tokens.is_none()
254    }
255
256    /// Merge another UsageResponse into this one. Any fields present in `other` will overwrite
257    /// the current values.
258    pub fn merge(&mut self, other: &UsageResponse) {
259        if other.prompt_tokens.is_some() {
260            self.prompt_tokens = other.prompt_tokens;
261        }
262
263        if other.completion_tokens.is_some() {
264            self.completion_tokens = other.completion_tokens;
265        }
266
267        if other.total_tokens.is_some() {
268            self.total_tokens = other.total_tokens;
269        }
270    }
271}
272
/// Metadata about the request, from the proxy.
#[derive(Debug, Clone, Serialize)]
pub struct RequestInfo {
    /// A UUID assigned by Chronicle to the request, which is linked to the logged information.
    pub id: Uuid,
    /// Which provider was used for the successful request.
    pub provider: String,
    /// Which model was used for the request
    pub model: String,
    /// How many times we had to retry before we got a successful response.
    pub num_retries: u32,
    /// If we retried due to hitting a rate limit.
    pub was_rate_limited: bool,
}
287
/// Metadata about the response, from the provider.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseInfo {
    /// Any other metadata from the provider that should be logged, as arbitrary JSON.
    pub meta: Option<serde_json::Value>,
    /// The model used for the request, as returned by the provider.
    pub model: String,
}
296
/// Part of a streaming response, returned from the proxy.
// Serialize is only derived for tests.
#[cfg_attr(test, derive(Serialize))]
#[derive(Debug, Clone)]
pub enum StreamingResponse {
    /// Metadata about the request, from the proxy. This will always be the first message in the
    /// stream.
    RequestInfo(RequestInfo),
    /// A chunk of a streaming response.
    Chunk(StreamingChatResponse),
    /// The chat response is completely in this one message. Used for non-streaming requests.
    Single(SingleChatResponse),
    /// Metadata about the response, from the provider. This chunk might not be sent.
    ResponseInfo(ResponseInfo),
}
311
/// A channel on which streaming responses can be sent. Each item is either a
/// [`StreamingResponse`] or a provider error report.
pub type StreamingResponseSender = flume::Sender<Result<StreamingResponse, Report<ProviderError>>>;
/// A channel that can receive streaming responses (the receiving half of
/// [`StreamingResponseSender`]).
pub type StreamingResponseReceiver =
    flume::Receiver<Result<StreamingResponse, Report<ProviderError>>>;
317
/// For providers that conform almost, but not quite, to the OpenAI spec, these transformations
/// apply small changes that can alter the request in place to the form needed for the provider.
/// Applied by [`ChatRequest::transform`].
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ChatRequestTransformation<'a> {
    /// True if the model provider supports a `name` for each message. False if name
    /// should be merged into the main content of the message when it is provided.
    pub supports_message_name: bool,
    /// True if the system message is just another message with a system role.
    /// False if it is the special `system` field.
    pub system_in_messages: bool,
    /// If the model starts with this prefix, strip it off.
    pub strip_model_prefix: Option<Cow<'a, str>>,
}
332
333impl<'a> Default for ChatRequestTransformation<'a> {
334    /// The default values match OpenAI's behavior
335    fn default() -> Self {
336        Self {
337            supports_message_name: true,
338            system_in_messages: true,
339            strip_model_prefix: Default::default(),
340        }
341    }
342}
343
/// The request that can be submitted to the proxy, for transformation and submission to a
/// provider.
///
/// Unset optional fields are left out of the serialized request, with the exception of
/// `max_tokens`.
#[serde_as]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatRequest {
    /// The messages in the chat so far.
    pub messages: Vec<ChatMessage>,
    /// A separate field for system message as an alternative to specifying it in
    /// `messages`. [`ChatRequest::transform`] moves it into or out of `messages` as needed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// The model to use. This can be omitted if the proxy options specify a model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// How to penalize tokens based on their frequency in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Specific control of certain token probabilities, keyed by token ID
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<BTreeMap<usize, f32>>,
    /// Return the logprobs of the generated tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// If `logprobs` is true, how many logprobs to return per token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    // NOTE(review): unlike every other optional field, this has no `skip_serializing_if`, so an
    // unset value is serialized as `null`. Confirm that is intended for every provider.
    /// max_tokens is optional for some providers but you should include it.
    /// We don't require it here for compatibility when wrapping other libraries that may not be aware they
    /// are using Chronicle.
    pub max_tokens: Option<u32>,
    /// Generate multiple chat completions concurrently. Not every model provider supports this.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// How to penalize tokens based on their existing presence in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Force JSON output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<serde_json::Value>,
    /// A random seed to use when generating the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Tell the model to stop when it encounters these token sequences.
    /// Accepts a single string or a list when deserializing; always serialized as a list.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[serde_as(as = "OneOrMany<_, PreferMany>")]
    pub stop: Vec<String>,
    /// Temperature to use when generating the response
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Customize the top-P probability of tokens to consider when generating the response
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Tools available for the model to use
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Customize how the model chooses tools
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
    /// The "user" to send to the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Send the response back as a stream of chunks. Defaults to `false` when absent.
    #[serde(default)]
    pub stream: bool,
    /// For OpenAI, this lets us enable usage when streaming. Chronicle will set this
    /// automatically when appropriate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
}
413
/// Stream options for OpenAI. This is automatically set by the proxy when streaming. You can omit
/// it in your requests.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StreamOptions {
    /// If true, include token usage in the streamed response.
    pub include_usage: bool,
}
421
422impl ChatRequest {
423    /// Transform a chat request to fit different variations on the OpenAI format.
424    pub fn transform(&mut self, options: &ChatRequestTransformation) {
425        let stripped = options
426            .strip_model_prefix
427            .as_deref()
428            .zip(self.model.as_deref())
429            .and_then(|(prefix, model)| model.strip_prefix(prefix));
430        if let Some(stripped) = stripped {
431            self.model = Some(stripped.to_string());
432        }
433
434        if !options.supports_message_name {
435            self.merge_message_names();
436        }
437
438        if options.system_in_messages {
439            self.move_system_to_messages();
440        } else {
441            self.move_system_message_to_top_level();
442        }
443    }
444
445    /// For providers that don't support a `name` field in their message,
446    /// convert messages with names to the format "{name}: {content}
447    pub fn merge_message_names(&mut self) {
448        for message in self.messages.iter_mut() {
449            if let Some(name) = message.name.take() {
450                message.content = message.content.as_deref().map(|c| format!("{name}: {c}"));
451            }
452        }
453    }
454
455    /// Move the entry in the `system` field to the start of `messages`.
456    pub fn move_system_to_messages(&mut self) {
457        let system = self.system.take();
458        if let Some(system) = system {
459            self.messages = std::iter::once(ChatMessage {
460                role: Some("system".to_string()),
461                content: Some(system),
462                tool_calls: Vec::new(),
463                name: None,
464                tool_call_id: None,
465            })
466            .chain(self.messages.drain(..))
467            .collect();
468        }
469    }
470
471    /// Move a `system` role [ChatMessage] to the `system` field
472    pub fn move_system_message_to_top_level(&mut self) {
473        if self
474            .messages
475            .get(0)
476            .map(|m| m.role.as_deref().unwrap_or_default() == "system")
477            .unwrap_or(false)
478        {
479            let system = self.messages.remove(0);
480            self.system = system.content;
481        }
482    }
483}
484
/// Represents a tool that can be used by the OpenAI model
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tool {
    /// The type of the tool, typically "function". Serialized as `type`.
    #[serde(rename = "type")]
    pub typ: String,
    /// The function details of the tool
    pub function: FunctionTool,
}
494
/// Represents the function details of a tool
// NOTE(review): `description` and `parameters` have no `skip_serializing_if`, so unset values
// are serialized as `null`. Confirm every provider accepts that.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FunctionTool {
    /// The name of the function
    pub name: String,
    /// An optional description of the function
    pub description: Option<String>,
    /// Optional parameters for the function, represented as a JSON value
    pub parameters: Option<serde_json::Value>,
}
505
506/// Represents a call to a tool by the OpenAI model
507#[derive(Serialize, Deserialize, Debug, Clone)]
508pub struct ToolCall {
509    /// The optional index of the tool call
510    #[serde(skip_serializing_if = "Option::is_none")]
511    pub index: Option<usize>,
512    /// The optional ID of the tool call
513    #[serde(skip_serializing_if = "Option::is_none")]
514    pub id: Option<String>,
515    /// The optional type of the tool call, typically "function"
516    #[serde(skip_serializing_if = "Option::is_none")]
517    #[serde(rename = "type")]
518    pub typ: Option<String>,
519    /// The function details of the tool call
520    pub function: ToolCallFunction,
521}
522
523impl ToolCall {
524    /// Merges a delta ToolCall into this ToolCall, updating fields if they are None
525    fn merge_delta(&mut self, delta: &ToolCall) {
526        if self.index.is_none() {
527            self.index = delta.index;
528        }
529        if self.id.is_none() {
530            self.id = delta.id.clone();
531        }
532        if self.typ.is_none() {
533            self.typ = delta.typ.clone();
534        }
535        if self.function.name.is_none() {
536            self.function.name = delta.function.name.clone();
537        }
538
539        if self.function.arguments.is_none() {
540            self.function.arguments = delta.function.arguments.clone();
541        } else if delta.function.arguments.is_some() {
542            self.function
543                .arguments
544                .as_mut()
545                .unwrap()
546                .push_str(&delta.function.arguments.as_ref().unwrap());
547        }
548    }
549}
550
/// Represents the function details of a tool call
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolCallFunction {
    /// The optional name of the function being called
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The optional arguments passed to the function, as a JSON string. While streaming this
    /// arrives in fragments that are concatenated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
561
#[cfg(test)]
mod tests {
    use super::{
        ChatChoiceDelta, ChatMessage, FinishReason, SingleChatResponse, StreamingChatResponse,
        ToolCall, ToolCallFunction, UsageResponse,
    };

    #[test]
    fn finish_reason_serialization() {
        let cases = [
            (FinishReason::Stop, "stop"),
            (FinishReason::Length, "length"),
            (FinishReason::ContentFilter, "content_filter"),
            (FinishReason::ToolCalls, "tool_calls"),
            (FinishReason::Other("custom_reason".into()), "custom_reason"),
        ];

        for (finish_reason, expected_str) in cases {
            let serialized = serde_json::to_value(&finish_reason).unwrap();
            assert_eq!(serialized, serde_json::json!(expected_str));
        }
    }

    #[test]
    fn finish_reason_deserialization() {
        let cases = [
            ("stop", FinishReason::Stop),
            ("length", FinishReason::Length),
            ("content_filter", FinishReason::ContentFilter),
            ("tool_calls", FinishReason::ToolCalls),
            ("custom_reason", FinishReason::Other("custom_reason".into())),
        ];

        for (json_str, expected_enum) in cases {
            let deserialized: FinishReason =
                serde_json::from_value(serde_json::json!(json_str)).unwrap();
            assert_eq!(deserialized.as_str(), expected_enum.as_str());
        }
    }

    /// Build a one-choice streaming chunk carrying `delta`.
    fn chunk(delta: ChatMessage, finish_reason: Option<FinishReason>) -> StreamingChatResponse {
        StreamingChatResponse {
            created: 1,
            model: Some("test-model".to_string()),
            system_fingerprint: None,
            choices: vec![ChatChoiceDelta {
                index: 0,
                delta,
                finish_reason,
            }],
            usage: None,
        }
    }

    #[test]
    fn merge_delta_concatenates_content_and_tool_arguments() {
        let mut collected = SingleChatResponse::new_for_collection(1);

        collected.merge_delta(&chunk(
            ChatMessage {
                role: Some("assistant".to_string()),
                content: Some("Hel".to_string()),
                ..Default::default()
            },
            None,
        ));
        collected.merge_delta(&chunk(
            ChatMessage {
                content: Some("lo".to_string()),
                tool_calls: vec![ToolCall {
                    index: Some(0),
                    id: Some("call_1".to_string()),
                    typ: Some("function".to_string()),
                    function: ToolCallFunction {
                        name: Some("lookup".to_string()),
                        arguments: Some("{\"q\":".to_string()),
                    },
                }],
                ..Default::default()
            },
            None,
        ));
        let mut last = chunk(
            ChatMessage {
                tool_calls: vec![ToolCall {
                    index: Some(0),
                    id: None,
                    typ: None,
                    function: ToolCallFunction {
                        name: None,
                        arguments: Some("\"x\"}".to_string()),
                    },
                }],
                ..Default::default()
            },
            Some(FinishReason::ToolCalls),
        );
        last.usage = Some(UsageResponse {
            prompt_tokens: Some(3),
            completion_tokens: Some(4),
            total_tokens: Some(7),
        });
        collected.merge_delta(&last);

        assert_eq!(collected.created, 1);
        assert_eq!(collected.model.as_deref(), Some("test-model"));
        assert_eq!(collected.choices.len(), 1);

        let choice = &collected.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.role.as_deref(), Some("assistant"));
        assert_eq!(choice.message.content.as_deref(), Some("Hello"));
        assert_eq!(choice.finish_reason.as_str(), "tool_calls");

        assert_eq!(choice.message.tool_calls.len(), 1);
        let call = &choice.message.tool_calls[0];
        assert_eq!(call.id.as_deref(), Some("call_1"));
        assert_eq!(call.function.name.as_deref(), Some("lookup"));
        assert_eq!(call.function.arguments.as_deref(), Some("{\"q\":\"x\"}"));

        let usage = collected.usage.as_ref().unwrap();
        assert_eq!(usage.prompt_tokens, Some(3));
        assert_eq!(usage.total_tokens, Some(7));
    }
}