openai_rs/endpoints/
mod.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::io::Error;
3
4pub mod completion;
5pub mod classification;
6pub mod answer;
7pub mod search;
8pub mod edits;
9
10use serde::{Deserialize, Serialize};
11
12/// This request-Module is for internal purpose
13pub(crate) mod request {
14    use hyper::{Body, Request};
15    use serde::Serialize;
16
17    macro_rules! post {
18        ($endpoint:ident, $auth_token:ident, $serialized:ident) => {{
19            hyper::http::Request::builder()
20                .method(hyper::http::method::Method::POST)
21                .uri($endpoint)
22                .header(hyper::header::AUTHORIZATION, &format!("Bearer {}", $auth_token))
23                .header(hyper::header::CONTENT_TYPE, "application/json")
24                .body(hyper::body::Body::from($serialized)).expect("Failed to build request")
25        }}
26    }
27    pub(super) use post;
28
29    /// An Endpoint-Trait which contains the ability to form a request.
30    /// This trait is mainly used for internal purpose (implementation of the Endpoint-Trait)
31    pub trait Endpoint
32    where Self: Serialize {
33        const ENDPOINT: &'static str;
34
35        fn request(
36            &self,
37            auth_token: &str,
38            engine_id: Option<&str>
39        ) -> Request<Body>;
40    }
41}
42
43#[derive(Debug, Clone, Deserialize)]
44pub struct Response {
45    pub id: Option<String>,
46    pub object: Option<String>,
47    pub created: Option<u64>,
48    pub model: Option<String>,
49    pub choices: Option<Vec<Choice>>,
50    pub data: Option<Vec<Data>>,
51    pub completion: Option<String>,
52    pub label: Option<String>,
53    pub search_model: Option<Model>,
54    pub selected_examples: Option<Vec<SelectedExample>>,
55    pub selected_documents: Option<Vec<SelectedDocument>>
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
59pub struct Choice {
60    pub text: String,
61    pub index: usize,
62    pub logprobs: Option<u32>,
63    pub finish_reason: Option<String>
64}
65
66#[derive(Debug, Clone, PartialEq, Deserialize)]
67pub struct Data {
68    pub document: u32,
69    pub object: String,
70    pub score: f32,
71}
72
73#[derive(Debug, Clone, PartialEq, Deserialize)]
74pub struct SelectedExample {
75    pub document: u32,
76    pub label: String,
77    pub text: String
78}
79
80#[derive(Debug, Clone, PartialEq, Deserialize)]
81pub struct SelectedDocument {
82    pub document: u32,
83    pub text: String
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
87#[serde(rename_all = "lowercase")]
88pub enum Model {
89    Ada,
90    Babbage,
91    Curie,
92    Davinci
93}
94
95#[derive(Debug)]
96pub enum ResponseError {
97    Io(Error),
98    Hyper(hyper::Error),
99    ErrorCode(hyper::StatusCode),
100    Serialization(serde_json::Error),
101}
102
103impl Default for Model {
104    fn default() -> Self {
105        Self::Ada
106    }
107}
108
109impl Display for ResponseError {
110    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
111        match self {
112            ResponseError::Io(error) => write!(f, "IO error: {}", error),
113            ResponseError::Hyper(error) => write!(f, "Hyper error: {}", error),
114            ResponseError::ErrorCode(status) => write!(f, "Error code: {}", status),
115            ResponseError::Serialization(error) => write!(f, "Serialization error: {}", error),
116        }
117    }
118}
119
120impl From<serde_json::Error> for ResponseError {
121    fn from(error: serde_json::Error) -> Self {
122        Self::Serialization(error)
123    }
124}
125
126impl From<Error> for ResponseError {
127    fn from(error: Error) -> Self {
128        Self::Io(error)
129    }
130}
131
132impl From<hyper::Error> for ResponseError {
133    fn from(error: hyper::Error) -> Self {
134        Self::Hyper(error)
135    }
136}
137
138impl std::error::Error for ResponseError {}