1use serde::{Deserialize, Serialize};
23use std::env;
24use std::thread;
25use std::time::{Duration, Instant};
26
/// Base URL every request path is appended to unless the client is
/// reconfigured through `OmniFlash::with_base_url`.
pub const DEFAULT_BASE_URL: &str = "https://omniflash.net/api/v1";
29
30#[derive(Debug, Clone, Deserialize, Serialize, Default)]
32pub struct Task {
33 pub task_id: String,
34 #[serde(default = "default_status")]
35 pub task_status: u8,
36 #[serde(default)]
37 pub image_url: Option<String>,
38 #[serde(default)]
39 pub video_url: Option<String>,
40 #[serde(default)]
41 pub audio_url: Option<String>,
42 #[serde(default)]
43 pub request_id: Option<String>,
44 #[serde(default)]
45 pub task_type: Option<String>,
46 #[serde(default)]
47 pub model_id: Option<String>,
48 #[serde(default)]
49 pub credits: Option<i64>,
50 #[serde(default)]
51 pub created_at: Option<i64>,
52 #[serde(default)]
53 pub msg: Option<String>,
54}
55
56fn default_status() -> u8 {
57 Task::STATUS_QUEUED
58}
59
60impl Task {
61 pub const STATUS_QUEUED: u8 = 1;
62 pub const STATUS_RUNNING: u8 = 2;
63 pub const STATUS_SUCCESS: u8 = 3;
64 pub const STATUS_FAILED: u8 = 4;
65
66 pub fn is_done(&self) -> bool {
68 self.task_status == Self::STATUS_SUCCESS || self.task_status == Self::STATUS_FAILED
69 }
70
71 pub fn output_url(&self) -> Option<&str> {
73 self.video_url
74 .as_deref()
75 .or(self.image_url.as_deref())
76 .or(self.audio_url.as_deref())
77 }
78}
79
80#[derive(Debug, Clone, Default, Serialize)]
82pub struct CreateTaskInput {
83 pub model_id: String,
84 pub prompt: String,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub image_urls: Option<Vec<String>>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub aspect_ratio: Option<String>,
89}
90
/// Polling configuration consumed by `OmniFlash::run`.
#[derive(Clone, Debug)]
pub struct RunOptions {
    /// Pause between two consecutive status polls.
    pub poll_interval: Duration,

    /// Time budget, measured from just after task creation, before `run`
    /// gives up with a timeout error.
    pub max_wait: Duration,
}
97
98impl Default for RunOptions {
99 fn default() -> Self {
100 Self {
101 poll_interval: Duration::from_secs(3),
102 max_wait: Duration::from_secs(600),
103 }
104 }
105}
106
/// Errors returned by the client.
#[derive(Debug)]
pub enum Error {
    /// No non-empty API key was available when constructing the client.
    MissingApiKey,

    /// The request could not be completed or a payload could not be decoded;
    /// carries a human-readable description.
    Transport(String),

    /// The API reported a failure, or a task failed or timed out in `run`.
    Api {
        /// Application-level code (the envelope `code` field, or the failed
        /// task status), when one is known.
        code: Option<i64>,
        /// HTTP status of the response, when the failure came from one.
        status: Option<u16>,
        /// Human-readable description of the failure.
        message: String,
    },
}
118
119impl std::fmt::Display for Error {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 match self {
122 Error::MissingApiKey => write!(
123 f,
124 "missing API key — pass via builder or set OMNIFLASH_API_KEY. Get one at https://omniflash.net/"
125 ),
126 Error::Transport(m) => write!(f, "transport error: {m}"),
127 Error::Api { code, status, message } => {
128 write!(f, "api error")?;
129 if let Some(s) = status {
130 write!(f, " status={s}")?;
131 }
132 if let Some(c) = code {
133 write!(f, " code={c}")?;
134 }
135 write!(f, ": {message}")
136 }
137 }
138 }
139}
140
141impl std::error::Error for Error {}
142
143pub type Result<T> = std::result::Result<T, Error>;
144
145pub struct OmniFlash {
147 api_key: String,
148 base_url: String,
149 agent: ureq::Agent,
150}
151
152impl OmniFlash {
153 pub fn new(api_key: Option<&str>) -> Result<Self> {
155 let key = api_key
156 .map(str::to_owned)
157 .or_else(|| env::var("OMNIFLASH_API_KEY").ok())
158 .filter(|s| !s.is_empty())
159 .ok_or(Error::MissingApiKey)?;
160 Ok(Self {
161 api_key: key,
162 base_url: DEFAULT_BASE_URL.to_owned(),
163 agent: ureq::AgentBuilder::new()
164 .timeout(Duration::from_secs(60))
165 .build(),
166 })
167 }
168
169 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
171 self.base_url = base_url.into().trim_end_matches('/').to_owned();
172 self
173 }
174
175 pub fn create_task(&self, input: &CreateTaskInput) -> Result<Task> {
177 let value = self.request("POST", "/tasks/create", Some(serde_json::to_value(input).unwrap()))?;
178 let task: Task = serde_json::from_value(value)
179 .map_err(|e| Error::Transport(format!("invalid task payload: {e}")))?;
180 Ok(task)
181 }
182
183 pub fn get_task(&self, task_id: &str) -> Result<Task> {
185 let value = self.request("GET", &format!("/tasks/{}", task_id), None)?;
186 serde_json::from_value(value).map_err(|e| Error::Transport(format!("invalid task payload: {e}")))
187 }
188
189 pub fn run(&self, input: CreateTaskInput, options: Option<RunOptions>) -> Result<Task> {
191 let opts = options.unwrap_or_default();
192 let mut task = self.create_task(&input)?;
193 let deadline = Instant::now() + opts.max_wait;
194 while !task.is_done() {
195 if Instant::now() > deadline {
196 return Err(Error::Api {
197 code: None,
198 status: None,
199 message: format!(
200 "task {} did not finish within {:?}",
201 task.task_id, opts.max_wait
202 ),
203 });
204 }
205 thread::sleep(opts.poll_interval);
206 task = self.get_task(&task.task_id)?;
207 }
208 if task.task_status == Task::STATUS_FAILED {
209 return Err(Error::Api {
210 code: Some(Task::STATUS_FAILED as i64),
211 status: None,
212 message: task
213 .msg
214 .clone()
215 .unwrap_or_else(|| format!("task {} failed", task.task_id)),
216 });
217 }
218 Ok(task)
219 }
220
221 fn request(
222 &self,
223 method: &str,
224 path: &str,
225 body: Option<serde_json::Value>,
226 ) -> Result<serde_json::Value> {
227 let url = format!("{}{}", self.base_url, path);
228 let req = self
229 .agent
230 .request(method, &url)
231 .set("Authorization", &format!("Bearer {}", self.api_key))
232 .set("Accept", "application/json");
233 let response = match body {
234 Some(b) => req.send_json(b),
235 None => req.call(),
236 };
237 match response {
238 Ok(r) => {
239 let envelope: serde_json::Value = r
240 .into_json()
241 .map_err(|e| Error::Transport(format!("invalid JSON: {e}")))?;
242 if let Some(code) = envelope.get("code").and_then(|v| v.as_i64()) {
243 if code != 200 {
244 let message = envelope
245 .get("msg")
246 .and_then(|v| v.as_str())
247 .unwrap_or("business error")
248 .to_owned();
249 return Err(Error::Api {
250 code: Some(code),
251 status: None,
252 message,
253 });
254 }
255 }
256 if let Some(data) = envelope.get("data").cloned() {
258 if !data.is_null() {
259 return Ok(data);
260 }
261 }
262 Ok(envelope)
263 }
264 Err(ureq::Error::Status(status, resp)) => {
265 if status == 401 {
266 return Err(Error::Api {
267 code: None,
268 status: Some(401),
269 message: "unauthorized — check your OMNIFLASH_API_KEY (https://omniflash.net/)".to_owned(),
270 });
271 }
272 let body = resp.into_string().unwrap_or_default();
273 let msg = serde_json::from_str::<serde_json::Value>(&body)
274 .ok()
275 .and_then(|v| v.get("msg").and_then(|m| m.as_str().map(str::to_owned)))
276 .unwrap_or_else(|| format!("HTTP {status}"));
277 Err(Error::Api {
278 code: None,
279 status: Some(status),
280 message: msg,
281 })
282 }
283 Err(e) => Err(Error::Transport(e.to_string())),
284 }
285 }
286}