1use std::time::Duration;
2
3use crate::error::*;
4use http::{
5 header::{self, HeaderValue},
6 Method,
7};
8use reqwest::Body;
9use smart_default::SmartDefault;
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, SmartDefault)]
12pub struct GenerationRequest {
13 pub model: String,
14 pub prompt: String,
15 pub size: Option<String>,
16 pub n: Option<i32>,
17 pub response_format: Option<GenerationFormat>,
18 pub seed: Option<i32>,
19 pub steps: Option<i32>,
20 pub cfg_scale: Option<f32>,
21}
22
23impl GenerationRequest {
24 pub fn builder() -> GenerationRequestBuilder {
25 GenerationRequestBuilder::default()
26 }
27
28 pub async fn call(
29 &self,
30 client: &crate::client::Client,
31 timeout: Option<Duration>,
32 ) -> Result<GenerationResponse> {
33 let uri = "images/generations";
34
35 let rep = client
36 .call_impl(
37 Method::POST,
38 uri,
39 vec![(
40 header::CONTENT_TYPE,
41 HeaderValue::from_str("application/json")?,
42 )],
43 Some(Body::from(serde_json::to_string(&self)?)),
44 None,
45 timeout,
46 )
47 .await?;
48
49 let status = rep.status();
50
51 let rep = serde_json::from_slice::<serde_json::Value>(rep.bytes().await?.as_ref())?;
52
53 for l in serde_json::to_string_pretty(&rep)?.lines() {
54 if status.is_client_error() || status.is_server_error() {
55 tracing::error!("REP: {}", l);
56 } else {
57 tracing::trace!("REP: {}", l);
58 }
59 }
60
61 if !status.is_success() {
62 return Err(Error::ApiError(status.as_u16()));
63 }
64
65 Ok(serde_json::from_value(rep)?)
66 }
67}
68
69#[derive(Debug, Clone, SmartDefault)]
70pub struct GenerationRequestBuilder {
71 model: Option<String>,
72 prompt: Option<String>,
73 size: Option<String>,
74 n: Option<i32>,
75 response_format: Option<GenerationFormat>,
76 seed: Option<i32>,
77 steps: Option<i32>,
78 cfg_scale: Option<f32>,
79}
80
81impl GenerationRequestBuilder {
82 pub fn with_model(mut self, model: impl Into<String>) -> Self {
83 self.model = Some(model.into());
84 self
85 }
86
87 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
88 self.prompt = Some(prompt.into());
89 self
90 }
91
92 pub fn with_size(mut self, width: i32, height: i32) -> Self {
93 self.size = Some(format!("{}x{}", width, height));
94 self
95 }
96
97 pub fn with_n(mut self, n: i32) -> Self {
98 self.n = Some(n);
99 self
100 }
101
102 pub fn with_response_format(mut self, response_format: GenerationFormat) -> Self {
103 self.response_format = Some(response_format);
104 self
105 }
106
107 pub fn with_seed(mut self, seed: i32) -> Self {
108 self.seed = Some(seed);
109 self
110 }
111
112 pub fn with_steps(mut self, steps: i32) -> Self {
113 self.steps = Some(steps);
114 self
115 }
116
117 pub fn with_cfg_scale(mut self, cfg_scale: f32) -> Self {
118 self.cfg_scale = Some(cfg_scale);
119 self
120 }
121
122 pub fn build(self) -> Result<GenerationRequest> {
123 let Self {
124 model,
125 prompt,
126 size,
127 n,
128 response_format,
129 seed,
130 steps,
131 cfg_scale,
132 } = self;
133
134 Ok(GenerationRequest {
135 model: model.ok_or(Error::GenerationRequestBuild)?,
136 prompt: prompt.ok_or(Error::GenerationRequestBuild)?,
137 size,
138 n,
139 response_format,
140 seed,
141 steps,
142 cfg_scale,
143 })
144 }
145}
146
147#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
148pub struct GenerationResponse {
149 pub created: u64,
150 pub data: Vec<GenerationData>,
151}
152
153#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
154pub struct GenerationData {
155 pub seed: i32,
156 pub finish_reason: String,
157 pub image: Option<String>,
158 pub url: Option<String>,
159}
160
161#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
162#[allow(non_camel_case_types)]
163pub enum GenerationFormat {
164 b64_json,
165 url,
166}
167
/// Live smoke test against the real image-generation endpoint.
///
/// This test is ignored by default. It needs network access, the credentials
/// in `.env.stepfun.genai`, and the `OPENAI_API_MODEL_NAME` variable, and it
/// issues a real generation request. Run it explicitly with
/// `cargo test -- --ignored test_genai_ok`.
#[cfg(test)]
#[tokio::test]
#[ignore = "requires network access and credentials in .env.stepfun.genai"]
async fn test_genai_ok() -> Result<()> {
    use crate::client::Client;

    // Install the subscriber first so anything logged while loading the env file is visible.
    let _ = tracing_subscriber::fmt::try_init();
    let client = Client::from_env_file(".env.stepfun.genai")?;

    let model_name = std::env::var("OPENAI_API_MODEL_NAME")?;

    let res = GenerationRequest::builder()
        .with_prompt("Sweet and Sour Mandarin Fish, a traditional Chinese dish.")
        .with_model(model_name)
        .build()?
        .call(&client, None)
        .await?;

    for data in serde_json::to_string_pretty(&res)?.lines() {
        tracing::info!("REP: {}", data);
    }

    Ok(())
}