gpt_batch_scribe/
gpt_batch_api_request.rs1crate::ix!();
18
/// One request destined for a batch input file.
///
/// `method` and `url` are (de)serialized through the local `http_method` and
/// `api_url` helper modules, so they appear as plain strings in the JSON form.
/// Field order here is the order fields are emitted when serialized.
#[derive(Debug, Serialize, Deserialize)]
pub struct GptBatchAPIRequest {

    /// Per-request identifier; the constructors in this module generate it as
    /// `request-{idx}` via `custom_id_for_idx`.
    custom_id: CustomRequestId,

    /// HTTP method; the custom deserializer only accepts the string `"POST"`.
    #[serde(with = "http_method")]
    method: HttpMethod,

    /// Target endpoint; the custom deserializer only accepts
    /// `"/v1/chat/completions"`.
    #[serde(with = "api_url")]
    url: GptApiUrl,

    /// Request payload, serialized into the batch line's `body` field.
    body: GptRequestBody,
}
37
38impl GptBatchAPIRequest {
39 pub fn custom_id(&self) -> &CustomRequestId {
40 &self.custom_id
41 }
42}
43
44impl From<GptBatchAPIRequest> for BatchRequestInput {
45
46 fn from(request: GptBatchAPIRequest) -> Self {
47 BatchRequestInput {
48 custom_id: request.custom_id.to_string(),
49 method: BatchRequestInputMethod::POST,
50 url: match request.url {
51 GptApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
52 },
53 body: Some(serde_json::to_value(&request.body).unwrap()),
54 }
55 }
56}
57
58pub fn create_batch_input_file(
59 requests: &[GptBatchAPIRequest],
60 batch_input_filename: impl AsRef<Path>,
61
62) -> Result<(), BatchInputCreationError> {
63
64 use std::io::{BufWriter,Write};
65 use std::fs::File;
66
67 let file = File::create(batch_input_filename.as_ref())?;
68 let mut writer = BufWriter::new(file);
69
70 for request in requests {
71 let batch_input = BatchRequestInput {
72 custom_id: request.custom_id.to_string(),
73 method: match request.method {
74 HttpMethod::Post => BatchRequestInputMethod::POST,
75 _ => unimplemented!("Only POST method is supported"),
76 },
77 url: match request.url {
78 GptApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
79 },
81 body: Some(serde_json::to_value(&request.body)?),
82 };
83 let line = serde_json::to_string(&batch_input)?;
84 writeln!(writer, "{}", line)?;
85 }
86
87 Ok(())
88}
89
90impl GptBatchAPIRequest {
91
92 pub fn requests_from_query_strings(system_message: &str, model: GptModelType, queries: &[String]) -> Vec<Self> {
93 queries.iter().enumerate().map(|(idx,query)| Self::new_basic(model,idx,system_message,&query)).collect()
94 }
95
96 pub fn new_basic(model: GptModelType, idx: usize, system_message: &str, user_message: &str) -> Self {
97 Self {
98 custom_id: Self::custom_id_for_idx(idx),
99 method: HttpMethod::Post,
100 url: GptApiUrl::ChatCompletions,
101 body: GptRequestBody::new_basic(model,system_message,user_message),
102 }
103 }
104
105 pub fn new_with_image(model: GptModelType, idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
106 Self {
107 custom_id: Self::custom_id_for_idx(idx),
108 method: HttpMethod::Post,
109 url: GptApiUrl::ChatCompletions,
110 body: GptRequestBody::new_with_image(model,system_message,user_message,image_b64),
111 }
112 }
113}
114
115impl Display for GptBatchAPIRequest {
116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117 match serde_json::to_string(self) {
118 Ok(json) => write!(f, "{}", json),
119 Err(e) => {
120 write!(f, "Error serializing to JSON: {}", e)
122 }
123 }
124 }
125}
126
127impl GptBatchAPIRequest {
128
129 pub(crate) fn custom_id_for_idx(idx: usize) -> CustomRequestId {
130 CustomRequestId::new(format!("request-{}",idx))
131 }
132}
133
134mod http_method {
136
137 use super::*;
138
139 pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
140 where
141 S: Serializer,
142 {
143 serializer.serialize_str(&value.to_string())
144 }
145
146 pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
147 where
148 D: Deserializer<'de>,
149 {
150 let s: String = Deserialize::deserialize(deserializer)?;
151 match s.as_ref() {
152 "POST" => Ok(HttpMethod::Post),
153 _ => Err(serde::de::Error::custom("unknown method")),
154 }
155 }
156}
157
158mod api_url {
159
160 use super::*;
161
162 pub fn serialize<S>(value: &GptApiUrl, serializer: S) -> Result<S::Ok, S::Error>
163 where
164 S: Serializer,
165 {
166 serializer.serialize_str(&value.to_string())
167 }
168
169 pub fn deserialize<'de, D>(deserializer: D) -> Result<GptApiUrl, D::Error>
170 where
171 D: Deserializer<'de>,
172 {
173 let s: String = Deserialize::deserialize(deserializer)?;
174 match s.as_ref() {
175 "/v1/chat/completions" => Ok(GptApiUrl::ChatCompletions),
176 _ => Err(serde::de::Error::custom("unknown URL")),
177 }
178 }
179}