// gpt_batch_scribe/gpt_batch_api_request.rs

/*!
  | Basic example of a single batch API request, serialized as one JSON
  | object (the batch input file holds one such object per line):
  | {
  |     "custom_id": "request-1",
  |     "method": "POST",
  |     "url": "/v1/chat/completions",
  |     "body": {
  |         "model": "gpt-4",
  |         "messages": [
  |             {"role": "system", "content": "You are a helpful assistant."},
  |             {"role": "user", "content": "Hello world!"}
  |         ],
  |         "max_tokens": 1000
  |     }
  | }
  */
17crate::ix!();
18
19/// Represents the complete request structure.
20#[derive(Debug, Serialize, Deserialize)]
21pub struct GptBatchAPIRequest {
22
23    /// Identifier for the custom request.
24    custom_id: CustomRequestId,
25
26    /// HTTP method used for the request.
27    #[serde(with = "http_method")]
28    method: HttpMethod,
29
30    /// URL of the API endpoint.
31    #[serde(with = "api_url")]
32    url:  GptApiUrl,
33
34    /// Body of the request.
35    body: GptRequestBody,
36}
37
38impl GptBatchAPIRequest {
39    pub fn custom_id(&self) -> &CustomRequestId {
40        &self.custom_id
41    }
42}
43
44impl From<GptBatchAPIRequest> for BatchRequestInput {
45
46    fn from(request: GptBatchAPIRequest) -> Self {
47        BatchRequestInput {
48            custom_id: request.custom_id.to_string(),
49            method: BatchRequestInputMethod::POST,
50            url: match request.url {
51                GptApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
52            },
53            body: Some(serde_json::to_value(&request.body).unwrap()),
54        }
55    }
56}
57
58pub fn create_batch_input_file(
59    requests:             &[GptBatchAPIRequest],
60    batch_input_filename: impl AsRef<Path>,
61
62) -> Result<(), BatchInputCreationError> {
63
64    use std::io::{BufWriter,Write};
65    use std::fs::File;
66
67    let file = File::create(batch_input_filename.as_ref())?;
68    let mut writer = BufWriter::new(file);
69
70    for request in requests {
71        let batch_input = BatchRequestInput {
72            custom_id: request.custom_id.to_string(),
73            method: match request.method {
74                HttpMethod::Post => BatchRequestInputMethod::POST,
75                _ => unimplemented!("Only POST method is supported"),
76            },
77            url: match request.url {
78                GptApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
79                // Handle other endpoints if necessary
80            },
81            body: Some(serde_json::to_value(&request.body)?),
82        };
83        let line = serde_json::to_string(&batch_input)?;
84        writeln!(writer, "{}", line)?;
85    }
86
87    Ok(())
88}
89
90impl GptBatchAPIRequest {
91
92    pub fn requests_from_query_strings(system_message: &str, model: GptModelType, queries: &[String]) -> Vec<Self> {
93        queries.iter().enumerate().map(|(idx,query)| Self::new_basic(model,idx,system_message,&query)).collect()
94    }
95
96    pub fn new_basic(model: GptModelType, idx: usize, system_message: &str, user_message: &str) -> Self {
97        Self {
98            custom_id: Self::custom_id_for_idx(idx),
99            method:    HttpMethod::Post,
100            url:       GptApiUrl::ChatCompletions,
101            body:      GptRequestBody::new_basic(model,system_message,user_message),
102        }
103    }
104
105    pub fn new_with_image(model: GptModelType, idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
106        Self {
107            custom_id: Self::custom_id_for_idx(idx),
108            method:    HttpMethod::Post,
109            url:       GptApiUrl::ChatCompletions,
110            body:      GptRequestBody::new_with_image(model,system_message,user_message,image_b64),
111        }
112    }
113}
114
115impl Display for GptBatchAPIRequest {
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        match serde_json::to_string(self) {
118            Ok(json) => write!(f, "{}", json),
119            Err(e) => {
120                // Handle JSON serialization errors, though they shouldn't occur with proper struct definitions
121                write!(f, "Error serializing to JSON: {}", e)
122            }
123        }
124    }
125}
126
127impl GptBatchAPIRequest {
128
129    pub(crate) fn custom_id_for_idx(idx: usize) -> CustomRequestId {
130        CustomRequestId::new(format!("request-{}",idx))
131    }
132}
133
134/// Custom serialization modules for enum string representations.
135mod http_method {
136
137    use super::*;
138
139    pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
140    where
141        S: Serializer,
142    {
143        serializer.serialize_str(&value.to_string())
144    }
145
146    pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
147    where
148        D: Deserializer<'de>,
149    {
150        let s: String = Deserialize::deserialize(deserializer)?;
151        match s.as_ref() {
152            "POST" => Ok(HttpMethod::Post),
153            _ => Err(serde::de::Error::custom("unknown method")),
154        }
155    }
156}
157
158mod api_url {
159
160    use super::*;
161
162    pub fn serialize<S>(value: &GptApiUrl, serializer: S) -> Result<S::Ok, S::Error>
163    where
164        S: Serializer,
165    {
166        serializer.serialize_str(&value.to_string())
167    }
168
169    pub fn deserialize<'de, D>(deserializer: D) -> Result<GptApiUrl, D::Error>
170    where
171        D: Deserializer<'de>,
172    {
173        let s: String = Deserialize::deserialize(deserializer)?;
174        match s.as_ref() {
175            "/v1/chat/completions" => Ok(GptApiUrl::ChatCompletions),
176            _ => Err(serde::de::Error::custom("unknown URL")),
177        }
178    }
179}