kind_openai/endpoints/chat/standard.rs

1use std::collections::HashMap;
2
3use bon::Builder;
4use chat_completion_builder::IsComplete;
5use kind_openai_schema::OpenAISchema;
6use reqwest::Method;
7use serde::{Deserialize, Serialize};
8
9use crate::{endpoints::OpenAIRequestProvider, OpenAIResult, Usage};
10
11use super::{
12    structured::{ChatCompletionRequestResponseFormat, StructuredChatCompletion},
13    FinishReason, Message, Model, UnifiedChatCompletionResponseMessage,
14};
15
16/// A standard chat completion request. The response will be a string in any shape and will not
17/// be parsed.
18#[derive(Serialize, Builder)]
19#[builder(start_fn = model, finish_fn = unstructured, state_mod(vis = "pub"))]
20pub struct ChatCompletion<'a> {
21    #[builder(start_fn)]
22    model: Model,
23    messages: Vec<Message<'a>>,
24    temperature: Option<f32>,
25    top_p: Option<f32>,
26    store: Option<bool>,
27    metadata: Option<HashMap<String, String>>,
28    logit_bias: Option<HashMap<i32, i32>>,
29}
30
31impl OpenAIRequestProvider for ChatCompletion<'_> {
32    type Response = ChatCompletionResponse;
33
34    const METHOD: Method = Method::POST;
35
36    fn path_with_leading_slash() -> String {
37        "/chat/completions".to_string()
38    }
39}
40
41impl super::super::private::Sealed for ChatCompletion<'_> {}
42
43// this is a neat trick where we can take a completed builder and allow it to be "upgraded".
44// because of the `finish_fn` specification, we can either resolve and build immediately with
45// `.unstructured()`, or we can call `.structured()` and provide a schema. doing it this way
46// enables us to nicely represent the `ChatCompletionRequest` without having to specify the
47// generic type.
48impl<'a, S> ChatCompletionBuilder<'a, S>
49where
50    S: IsComplete,
51{
52    /// Upgrades a chat completion request to a structured chat completion request.
53    /// Unless the return type can be inferred, you probably want to call this like so:
54    /// `.structured::<MySchemadType>();`
55    pub fn structured<SS>(self) -> StructuredChatCompletion<'a, SS>
56    where
57        SS: OpenAISchema,
58    {
59        StructuredChatCompletion {
60            base_request: self.unstructured(),
61            response_format: ChatCompletionRequestResponseFormat::JsonSchema(SS::openai_schema()),
62            _phantom: std::marker::PhantomData,
63        }
64    }
65}
66
67/// A response from a chat completion request.
68#[derive(Deserialize)]
69pub struct ChatCompletionResponse {
70    choices: Vec<ChatCompletionResponseChoice>,
71    usage: Usage,
72}
73
74impl ChatCompletionResponse {
75    /// Takes the first message in the response consumes the response.
76    pub fn take_first_choice(self) -> Option<ChatCompletionResponseChoice> {
77        self.choices.into_iter().next()
78    }
79
80    /// Gives the usage tokens of the response.
81    pub fn usage(&self) -> &Usage {
82        &self.usage
83    }
84}
85
86/// A response choice from a chat completion request.
87#[derive(Deserialize)]
88pub struct ChatCompletionResponseChoice {
89    finish_reason: FinishReason,
90    index: i32,
91    message: ChatCompletionResponseMessage,
92}
93
94impl ChatCompletionResponseChoice {
95    /// Takes the message and returns a result that may contain a refusal.
96    pub fn message(self) -> OpenAIResult<String> {
97        Into::<UnifiedChatCompletionResponseMessage<String>>::into(self.message).into()
98    }
99
100    pub fn finish_reason(&self) -> FinishReason {
101        self.finish_reason
102    }
103
104    pub fn index(&self) -> i32 {
105        self.index
106    }
107}
108
109// leave private, messages should only be interacted with through the unified message type.
110#[derive(Deserialize)]
111struct ChatCompletionResponseMessage {
112    content: String,
113    refusal: Option<String>,
114}
115
116impl From<ChatCompletionResponseMessage> for UnifiedChatCompletionResponseMessage<String> {
117    fn from(value: ChatCompletionResponseMessage) -> Self {
118        UnifiedChatCompletionResponseMessage {
119            content: value.content,
120            refusal: value.refusal,
121        }
122    }
123}
124
/// Builds a `HashMap<i32, i32>` of `token_id: bias` pairs, the type used for the `logit_bias`
/// field of a chat completion request.
///
/// Each key must be a single token tree (typically an integer literal or an identifier); each
/// value is any expression. Keys and values are converted with `as i32`. If a key appears more
/// than once, the last pair wins. The empty form `logit_bias!()` produces an empty
/// `HashMap<i32, i32>`, so both forms have the same type.
///
/// ```ignore
/// let bias = logit_bias!(1734: -100, 2003: 5);
/// assert_eq!(bias.get(&1734), Some(&-100));
/// ```
#[macro_export]
macro_rules! logit_bias {
    () => {
        ::std::collections::HashMap::<i32, i32>::new()
    };

    // Keys are `tt` because an `expr` fragment may not be followed by `:` in a macro pattern.
    ($($key:tt : $value:expr),+ $(,)?) => {{
        let mut map: ::std::collections::HashMap<i32, i32> = ::std::collections::HashMap::new();
        $(
            map.insert($key as i32, $value as i32);
        )+
        map
    }};
}