Skip to main content

aiway_model_protocol/
shared.rs

1#[cfg(feature = "reqwest")]
2use crate::v1::error::APIError;
3use bytes::Bytes;
4#[cfg(feature = "reqwest")]
5use reqwest::{header::HeaderMap, multipart::Part};
6use serde::{Deserialize, Serialize};
7
8#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
9pub struct Usage {
10    /// Number of tokens in the prompt.
11    #[serde(skip_serializing_if = "Option::is_none")]
12    pub prompt_tokens: Option<u32>,
13    /// Number of tokens in the completion.
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub completion_tokens: Option<u32>,
16    /// Number of tokens in the entire response.
17    pub total_tokens: u32,
18    /// Breakdown of tokens used in the prompt.
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub prompt_tokens_details: Option<PromptTokensDetails>,
21    /// Breakdown of tokens used in a completion.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub completion_tokens_details: Option<CompletionTokensDetails>,
24}
25
26#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
27pub struct InputTokensDetails {
28    /// The number of tokens that were retrieved from the cache.
29    pub cached_tokens: u32,
30}
31
32#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
33pub struct OutputTokensDetails {
34    /// The number of reasoning tokens.
35    pub reasoning_tokens: u32,
36}
37
38#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
39pub struct PromptTokensDetails {
40    /// Audio input tokens present in the prompt.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub audio_tokens: Option<u32>,
43    /// Cached tokens present in the prompt.
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub cached_tokens: Option<u32>,
46}
47
48#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
49pub struct CompletionTokensDetails {
50    /// Tokens generated by the model for reasoning.
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub reasoning_tokens: Option<u32>,
53    /// Audio input tokens generated by the model.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub audio_tokens: Option<u32>,
56    /// When using Predicted Outputs, the number of tokens in the prediction that appeared in the completion.
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub accepted_prediction_tokens: Option<u32>,
59    /// When using Predicted Outputs, the number of tokens in the prediction that did not appear in the completion.
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub rejected_prediction_tokens: Option<u32>,
62}
63
64#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
65pub struct ResponseWrapper<T> {
66    pub data: T,
67    pub headers: Headers,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
71pub struct Headers {
72    /// The maximum number of requests that are permitted before exhausting the rate limit.
73    #[serde(rename = "x-ratelimit-limit-requests")]
74    pub x_ratelimit_limit_requests: Option<u32>,
75    /// The maximum number of tokens that are permitted before exhausting the rate limit.
76    #[serde(rename = "x-ratelimit-limit-tokens")]
77    pub x_ratelimit_limit_tokens: Option<u32>,
78    /// The remaining number of requests that are permitted before exhausting the rate limit.
79    #[serde(rename = "x-ratelimit-remaining-requests")]
80    pub x_ratelimit_remaining_requests: Option<u32>,
81    /// The remaining number of tokens that are permitted before exhausting the rate limit.
82    #[serde(rename = "x-ratelimit-remaining-tokens")]
83    pub x_ratelimit_remaining_tokens: Option<u32>,
84    /// The time until the rate limit (based on requests) resets to its initial state.
85    #[serde(rename = "x-ratelimit-reset-requests")]
86    pub x_ratelimit_reset_requests: Option<String>,
87    /// The time until the rate limit (based on tokens) resets to its initial state.
88    #[serde(rename = "x-ratelimit-reset-tokens")]
89    pub x_ratelimit_reset_tokens: Option<String>,
90}
91
92#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
93pub struct SimpleListParameters {
94    /// Identifier for the last object from the previous pagination request.
95    pub after: Option<String>,
96    /// Number of objects to retrieve.
97    pub limit: Option<u32>,
98}
99
100#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
101pub struct ListParameters {
102    /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub limit: Option<u32>,
105    /// Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub order: Option<String>,
108    /// A cursor for use in pagination. after is an object ID that defines your place in the list.
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub after: Option<String>,
111    /// A cursor for use in pagination. before is an object ID that defines your place in the list.
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub before: Option<String>,
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
117pub struct ListResponse<T> {
118    // The object type, which is always "list".
119    pub object: String,
120    /// The list ob objects.
121    pub data: Vec<T>,
122    /// The ID of the first objects in the list.
123    pub first_id: Option<String>,
124    /// The ID of the last objects in the list.
125    pub last_id: Option<String>,
126    /// Indicates whether there are more objects to retrieve.
127    pub has_more: bool,
128}
129
130#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
131pub struct DeletedObject {
132    /// ID of the deleted object.
133    pub id: String,
134    /// The object type.
135    pub object: String,
136    /// Indicates whether the file was successfully deleted.
137    pub deleted: bool,
138}
139
140#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
141pub struct LastError {
142    /// One of 'server_error' or 'rate_limit_exceeded'.
143    pub code: LastErrorCode,
144    /// A human-readable description of the error.
145    pub message: String,
146}
147
148#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
149#[serde(rename_all = "snake_case")]
150pub enum LastErrorCode {
151    ServerError,
152    RateLimitExceeded,
153}
154
155#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
156#[serde(rename_all = "snake_case")]
157pub enum FinishReason {
158    /// API returned complete message, or a message terminated by one of the stop sequences provided via the stop parameter.
159    #[serde(rename = "stop", alias = "STOP")]
160    StopSequenceReached,
161    /// Incomplete model output due to max_tokens parameter or token limit.
162    #[serde(rename = "length", alias = "MAX_TOKENS")]
163    TokenLimitReached,
164    /// Omitted content due to a flag from our content filters.
165    #[serde(
166        rename = "content_filter",
167        alias = "SAFETY",
168        alias = "SPII",
169        alias = "PROHIBITED_CONTENT",
170        alias = "BLOCKLIST",
171        alias = "RECITATION"
172    )]
173    ContentFilterFlagged,
174    /// The model decided to call one or more tools.
175    ToolCalls,
176    /// The model reached a natural stopping point. [Claude]
177    EndTurn,
178    /// The finish reason is unspecified. [Gemini]
179    #[serde(rename = "FINISH_REASON_UNSPECIFIED	")]
180    FinishReasonUnspecified,
181    #[serde(rename = "MALFORMED_FUNCTION_CALL")]
182    MalformedFunctionCall,
183    #[serde(rename = "OTHER")]
184    Other,
185}
186
187#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
188#[serde(rename_all = "snake_case")]
189pub enum ReasoningEffort {
190    High,
191    Medium,
192    Low,
193    Minimal,
194}
195
196#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
197#[serde(rename_all = "snake_case")]
198pub enum WebSearchContextSize {
199    Low,
200    Medium,
201    Large,
202}
203
204#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
205#[serde(untagged)]
206pub enum StopToken {
207    String(String),
208    Array(Vec<String>),
209}
210
#[cfg(feature = "reqwest")]
impl From<HeaderMap> for Headers {
    /// Extracts the `x-ratelimit-*` values from an HTTP response's headers.
    ///
    /// Each field is read independently. A header that is missing, is not valid
    /// visible ASCII, or (for the numeric fields) does not parse as a `u32`
    /// becomes `None` for that field only. Because the conversion is infallible,
    /// malformed or partial rate-limit headers from an upstream server or proxy
    /// are tolerated instead of panicking.
    fn from(value: HeaderMap) -> Self {
        /// Header value as text; `None` when absent or not visible ASCII.
        fn text<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
            headers.get(name)?.to_str().ok()
        }

        /// Header value parsed as a `u32`, ignoring surrounding whitespace.
        fn number(headers: &HeaderMap, name: &str) -> Option<u32> {
            text(headers, name)?.trim().parse().ok()
        }

        /// Header value copied into an owned `String`.
        fn owned(headers: &HeaderMap, name: &str) -> Option<String> {
            text(headers, name).map(str::to_owned)
        }

        Self {
            x_ratelimit_limit_requests: number(&value, "x-ratelimit-limit-requests"),
            x_ratelimit_limit_tokens: number(&value, "x-ratelimit-limit-tokens"),
            x_ratelimit_remaining_requests: number(&value, "x-ratelimit-remaining-requests"),
            x_ratelimit_remaining_tokens: number(&value, "x-ratelimit-remaining-tokens"),
            x_ratelimit_reset_requests: owned(&value, "x-ratelimit-reset-requests"),
            x_ratelimit_reset_tokens: owned(&value, "x-ratelimit-reset-tokens"),
        }
    }
}
287
288#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
289pub struct FileUploadBytes {
290    pub bytes: Bytes,
291    pub filename: String,
292}
293impl FileUploadBytes {
294    pub fn new(bytes: impl Into<Bytes>, filename: impl Into<String>) -> Self {
295        Self {
296            bytes: bytes.into(),
297            filename: filename.into(),
298        }
299    }
300
301    #[cfg(feature = "reqwest")]
302    pub(crate) fn into_part(self) -> Result<Part, APIError> {
303        reqwest::multipart::Part::bytes(self.bytes.to_vec())
304            .file_name(self.filename.clone())
305            .mime_str("application/octet-stream")
306            .map_err(|error| APIError::FileError(error.to_string()))
307    }
308}
309
310#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
311pub enum FileUpload {
312    Bytes(FileUploadBytes),
313    BytesArray(Vec<FileUploadBytes>),
314    #[cfg(all(feature = "tokio", feature = "tokio-util"))]
315    File(String),
316    #[cfg(all(feature = "tokio", feature = "tokio-util"))]
317    FileArray(Vec<String>),
318}
319impl FileUpload {
320    #[cfg(feature = "reqwest")]
321    pub(crate) async fn into_part(self) -> Result<Part, APIError> {
322        match self {
323            FileUpload::Bytes(bytes) => bytes.into_part(),
324            FileUpload::BytesArray(_) => {
325                unimplemented!("BytesArray is not supported for this route")
326            }
327            #[cfg(all(feature = "tokio", feature = "tokio-util"))]
328            FileUpload::File(path) => {
329                use tokio::fs::File;
330                use tokio_util::codec::{BytesCodec, FramedRead};
331
332                let file = File::open(&path)
333                    .await
334                    .map_err(|error| APIError::FileError(error.to_string()))?;
335
336                let stream = FramedRead::new(file, BytesCodec::new());
337                let file_body = reqwest::Body::wrap_stream(stream);
338
339                let file_part = reqwest::multipart::Part::stream(file_body).file_name(path);
340                // .mime_str("application/octet-stream")
341                // .unwrap();
342
343                Ok(file_part)
344            }
345            #[cfg(all(feature = "tokio", feature = "tokio-util"))]
346            FileUpload::FileArray(_) => {
347                unimplemented!("FileArray is not supported for this route")
348            }
349        }
350    }
351
352    #[cfg(feature = "reqwest")]
353    pub(crate) async fn into_parts(self) -> Result<Vec<Part>, APIError> {
354        match self {
355            FileUpload::Bytes(bytes) => bytes.into_part().map(|part| vec![part]),
356            FileUpload::BytesArray(bytes) => bytes
357                .into_iter()
358                .map(|bytes| bytes.into_part())
359                .collect::<Result<Vec<Part>, APIError>>(),
360            #[cfg(all(feature = "tokio", feature = "tokio-util"))]
361            FileUpload::File(path) => {
362                use tokio::fs::File;
363                use tokio_util::codec::{BytesCodec, FramedRead};
364
365                let file = File::open(&path)
366                    .await
367                    .map_err(|error| APIError::FileError(error.to_string()))?;
368
369                let stream = FramedRead::new(file, BytesCodec::new());
370                let file_body = reqwest::Body::wrap_stream(stream);
371
372                let file_part = reqwest::multipart::Part::stream(file_body)
373                    .file_name(path)
374                    .mime_str("application/octet-stream")
375                    .unwrap();
376
377                Ok(vec![file_part])
378            }
379            #[cfg(all(feature = "tokio", feature = "tokio-util"))]
380            FileUpload::FileArray(paths) => {
381                use tokio::fs::File;
382                use tokio_util::codec::{BytesCodec, FramedRead};
383
384                let mut file_parts = vec![];
385                for path in paths {
386                    let file = File::open(&path)
387                        .await
388                        .map_err(|error| APIError::FileError(error.to_string()))?;
389
390                    let stream = FramedRead::new(file, BytesCodec::new());
391                    let file_body = reqwest::Body::wrap_stream(stream);
392
393                    let file_part = reqwest::multipart::Part::stream(file_body)
394                        .file_name(path)
395                        .mime_str("application/octet-stream")
396                        .unwrap();
397
398                    file_parts.push(file_part);
399                }
400
401                Ok(file_parts)
402            }
403        }
404    }
405}
406impl Default for FileUpload {
407    fn default() -> Self {
408        Self::Bytes(FileUploadBytes::new(Bytes::new(), ""))
409    }
410}
411
/// Returns the current Unix timestamp in whole seconds.
///
/// Never panics. If the system clock is set before the Unix epoch, the result
/// is `0`. Timestamps that no longer fit in a `u32` (after early 2106)
/// saturate to `u32::MAX` instead of silently wrapping.
pub(crate) fn default_created() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let seconds = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    u32::try_from(seconds).unwrap_or(u32::MAX)
}