//! Image generation request/response types (`openai_ng/proto/image.rs`).

1use std::time::Duration;
2
3use crate::error::*;
4use http::{
5    header::{self, HeaderValue},
6    Method,
7};
8use reqwest::Body;
9use smart_default::SmartDefault;
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, SmartDefault)]
12pub struct GenerationRequest {
13    pub model: String,
14    pub prompt: String,
15    pub size: Option<String>,
16    pub n: Option<i32>,
17    pub response_format: Option<GenerationFormat>,
18    pub seed: Option<i32>,
19    pub steps: Option<i32>,
20    pub cfg_scale: Option<f32>,
21}
22
23impl GenerationRequest {
24    pub fn builder() -> GenerationRequestBuilder {
25        GenerationRequestBuilder::default()
26    }
27
28    pub async fn call(
29        &self,
30        client: &crate::client::Client,
31        timeout: Option<Duration>,
32    ) -> Result<GenerationResponse> {
33        let uri = "images/generations";
34
35        let rep = client
36            .call_impl(
37                Method::POST,
38                uri,
39                vec![(
40                    header::CONTENT_TYPE,
41                    HeaderValue::from_str("application/json")?,
42                )],
43                Some(Body::from(serde_json::to_string(&self)?)),
44                None,
45                timeout,
46            )
47            .await?;
48
49        let status = rep.status();
50
51        let rep = serde_json::from_slice::<serde_json::Value>(rep.bytes().await?.as_ref())?;
52
53        for l in serde_json::to_string_pretty(&rep)?.lines() {
54            if status.is_client_error() || status.is_server_error() {
55                tracing::error!("REP: {}", l);
56            } else {
57                tracing::trace!("REP: {}", l);
58            }
59        }
60
61        if !status.is_success() {
62            return Err(Error::ApiError(status.as_u16()));
63        }
64
65        Ok(serde_json::from_value(rep)?)
66    }
67}
68
/// Builder for [`GenerationRequest`]; obtain via `GenerationRequest::builder()`.
///
/// `model` and `prompt` are required (checked in `build`); all other fields
/// are optional and pass through unchanged.
#[derive(Debug, Clone, SmartDefault)]
pub struct GenerationRequestBuilder {
    // Required at build time.
    model: Option<String>,
    // Required at build time.
    prompt: Option<String>,
    // "{width}x{height}" string, set via `with_size`.
    size: Option<String>,
    // Number of images to generate.
    n: Option<i32>,
    // b64_json or url delivery of the result.
    response_format: Option<GenerationFormat>,
    // RNG seed for reproducibility.
    seed: Option<i32>,
    // Diffusion step count.
    steps: Option<i32>,
    // Classifier-free guidance scale.
    cfg_scale: Option<f32>,
}
80
81impl GenerationRequestBuilder {
82    pub fn with_model(mut self, model: impl Into<String>) -> Self {
83        self.model = Some(model.into());
84        self
85    }
86
87    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
88        self.prompt = Some(prompt.into());
89        self
90    }
91
92    pub fn with_size(mut self, width: i32, height: i32) -> Self {
93        self.size = Some(format!("{}x{}", width, height));
94        self
95    }
96
97    pub fn with_n(mut self, n: i32) -> Self {
98        self.n = Some(n);
99        self
100    }
101
102    pub fn with_response_format(mut self, response_format: GenerationFormat) -> Self {
103        self.response_format = Some(response_format);
104        self
105    }
106
107    pub fn with_seed(mut self, seed: i32) -> Self {
108        self.seed = Some(seed);
109        self
110    }
111
112    pub fn with_steps(mut self, steps: i32) -> Self {
113        self.steps = Some(steps);
114        self
115    }
116
117    pub fn with_cfg_scale(mut self, cfg_scale: f32) -> Self {
118        self.cfg_scale = Some(cfg_scale);
119        self
120    }
121
122    pub fn build(self) -> Result<GenerationRequest> {
123        let Self {
124            model,
125            prompt,
126            size,
127            n,
128            response_format,
129            seed,
130            steps,
131            cfg_scale,
132        } = self;
133
134        Ok(GenerationRequest {
135            model: model.ok_or(Error::GenerationRequestBuild)?,
136            prompt: prompt.ok_or(Error::GenerationRequestBuild)?,
137            size,
138            n,
139            response_format,
140            seed,
141            steps,
142            cfg_scale,
143        })
144    }
145}
146
/// Successful response body from the `images/generations` endpoint.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GenerationResponse {
    // Creation timestamp — presumably Unix seconds; confirm against the API.
    pub created: u64,
    // One entry per generated image.
    pub data: Vec<GenerationData>,
}
152
/// A single generated image within a [`GenerationResponse`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GenerationData {
    // Seed actually used for this image.
    pub seed: i32,
    // Why generation stopped, as reported by the backend.
    pub finish_reason: String,
    // Base64-encoded image — set when `b64_json` format was requested.
    pub image: Option<String>,
    // Download URL — set when `url` format was requested.
    pub url: Option<String>,
}
160
/// How the generated image is delivered in the response.
///
/// Variant names intentionally mirror the API's lowercase wire values
/// (serde serializes variants by name), hence the `non_camel_case_types`
/// allow. Renaming them would change the serialized JSON and break callers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[allow(non_camel_case_types)]
pub enum GenerationFormat {
    // Image returned inline as base64 in `GenerationData::image`.
    b64_json,
    // Image returned as a download link in `GenerationData::url`.
    url,
}
167
/// Live integration test: requires `.env.stepfun.genai` credentials, the
/// `OPENAI_API_MODEL_NAME` env var, and network access to the API.
#[cfg(test)]
#[tokio::test]
async fn test_genai_ok() -> Result<()> {
    use crate::client::Client;

    let client = Client::from_env_file(".env.stepfun.genai")?;
    // Ignore the error if a subscriber is already installed by another test.
    let _ = tracing_subscriber::fmt::try_init();

    let model_name = std::env::var("OPENAI_API_MODEL_NAME")?;

    let res = GenerationRequest::builder()
        // Fixed prompt typos: "traitional" -> "traditional", "chinese" -> "Chinese".
        .with_prompt("Sweet and Sour Mandarin Fish, a Chinese traditional dish.")
        .with_model(model_name)
        .build()?
        .call(&client, None)
        .await?;

    for data in serde_json::to_string_pretty(&res)?.lines() {
        tracing::info!("REP: {}", data);
    }

    Ok(())
}