Skip to main content

omniflash_sdk/
lib.rs

//! Omni Flash Rust SDK — generate short videos (with synchronized audio) and images
//! using Google's **Gemini Omni Flash** family of models.
//!
//! Sign in at <https://omniflash.net/> and create an API key on the account page.
//! Export it as `OMNIFLASH_API_KEY` and `OmniFlash::new(None)` picks it up automatically.
//!
//! # Quick start
//!
//! ```no_run
//! use omniflash_sdk::{OmniFlash, CreateTaskInput};
//!
//! let client = OmniFlash::new(None).unwrap();
//! let task = client.run(CreateTaskInput {
//!     model_id: "seedance-2".into(),
//!     prompt: "a kettle whistles as steam rises".into(),
//!     aspect_ratio: Some("16:9".into()),
//!     ..Default::default()
//! }, None).unwrap();
//! println!("{:?} {:?}", task.video_url, task.audio_url);
//! ```
21
22use serde::{Deserialize, Serialize};
23use std::env;
24use std::thread;
25use std::time::{Duration, Instant};
26
/// Root of the hosted Omni Flash REST API, used by `OmniFlash::new`.
///
/// It has no trailing slash, so request paths such as `/tasks/create` are appended
/// directly. Point a client elsewhere with `OmniFlash::with_base_url`.
pub const DEFAULT_BASE_URL: &str = "https://omniflash.net/api/v1";
29
30/// State of a generation job.
31#[derive(Debug, Clone, Deserialize, Serialize, Default)]
32pub struct Task {
33    pub task_id: String,
34    #[serde(default = "default_status")]
35    pub task_status: u8,
36    #[serde(default)]
37    pub image_url: Option<String>,
38    #[serde(default)]
39    pub video_url: Option<String>,
40    #[serde(default)]
41    pub audio_url: Option<String>,
42    #[serde(default)]
43    pub request_id: Option<String>,
44    #[serde(default)]
45    pub task_type: Option<String>,
46    #[serde(default)]
47    pub model_id: Option<String>,
48    #[serde(default)]
49    pub credits: Option<i64>,
50    #[serde(default)]
51    pub created_at: Option<i64>,
52    #[serde(default)]
53    pub msg: Option<String>,
54}
55
56fn default_status() -> u8 {
57    Task::STATUS_QUEUED
58}
59
60impl Task {
61    pub const STATUS_QUEUED: u8 = 1;
62    pub const STATUS_RUNNING: u8 = 2;
63    pub const STATUS_SUCCESS: u8 = 3;
64    pub const STATUS_FAILED: u8 = 4;
65
66    /// Whether the task has reached a terminal state.
67    pub fn is_done(&self) -> bool {
68        self.task_status == Self::STATUS_SUCCESS || self.task_status == Self::STATUS_FAILED
69    }
70
71    /// Returns `video_url` if present, else `image_url`, else `audio_url`.
72    pub fn output_url(&self) -> Option<&str> {
73        self.video_url
74            .as_deref()
75            .or(self.image_url.as_deref())
76            .or(self.audio_url.as_deref())
77    }
78}
79
80/// Request body for `POST /tasks/create`.
81#[derive(Debug, Clone, Default, Serialize)]
82pub struct CreateTaskInput {
83    pub model_id: String,
84    pub prompt: String,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub image_urls: Option<Vec<String>>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub aspect_ratio: Option<String>,
89}
90
/// Tuning knobs for `OmniFlash::run`'s polling loop.
#[derive(Debug, Clone)]
pub struct RunOptions {
    /// Pause between consecutive status checks.
    pub poll_interval: Duration,
    /// How long `run` keeps polling, measured from after the task has been created.
    pub max_wait: Duration,
}

impl Default for RunOptions {
    /// Check every 3 seconds and give up after 10 minutes.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(3);
        let max_wait = Duration::from_secs(10 * 60);
        RunOptions {
            poll_interval,
            max_wait,
        }
    }
}
106
/// Errors returned by the SDK.
#[derive(Debug)]
pub enum Error {
    /// No API key was supplied and `OMNIFLASH_API_KEY` is unset or blank.
    MissingApiKey,
    /// The HTTP call failed, or a response body could not be decoded.
    Transport(String),
    /// The API reported an error. This variant also covers failed tasks and polling
    /// timeouts in `OmniFlash::run`.
    Api {
        /// Business code from the response envelope, or the task status for a failed task.
        code: Option<i64>,
        /// HTTP status code, when the failure came from a non-success HTTP response.
        status: Option<u16>,
        /// Human-readable description.
        message: String,
    },
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // There is no builder in this crate: the key goes to `OmniFlash::new`.
            Error::MissingApiKey => write!(
                f,
                "missing API key — pass one to OmniFlash::new or set OMNIFLASH_API_KEY. Get one at https://omniflash.net/"
            ),
            Error::Transport(m) => write!(f, "transport error: {m}"),
            Error::Api {
                code,
                status,
                message,
            } => {
                f.write_str("api error")?;
                if let Some(s) = status {
                    write!(f, " status={s}")?;
                }
                if let Some(c) = code {
                    write!(f, " code={c}")?;
                }
                write!(f, ": {message}")
            }
        }
    }
}

impl std::error::Error for Error {}

/// Result type used throughout the SDK.
pub type Result<T> = std::result::Result<T, Error>;
144
145/// Client for the Omni Flash API.
146pub struct OmniFlash {
147    api_key: String,
148    base_url: String,
149    agent: ureq::Agent,
150}
151
152impl OmniFlash {
153    /// Construct a new client. If `api_key` is `None`, reads `OMNIFLASH_API_KEY` from the environment.
154    pub fn new(api_key: Option<&str>) -> Result<Self> {
155        let key = api_key
156            .map(str::to_owned)
157            .or_else(|| env::var("OMNIFLASH_API_KEY").ok())
158            .filter(|s| !s.is_empty())
159            .ok_or(Error::MissingApiKey)?;
160        Ok(Self {
161            api_key: key,
162            base_url: DEFAULT_BASE_URL.to_owned(),
163            agent: ureq::AgentBuilder::new()
164                .timeout(Duration::from_secs(60))
165                .build(),
166        })
167    }
168
169    /// Override the base URL (useful for testing or self-hosting).
170    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
171        self.base_url = base_url.into().trim_end_matches('/').to_owned();
172        self
173    }
174
175    /// Submit a generation job (returns task with task_id, status=Queued).
176    pub fn create_task(&self, input: &CreateTaskInput) -> Result<Task> {
177        let value = self.request("POST", "/tasks/create", Some(serde_json::to_value(input).unwrap()))?;
178        let task: Task = serde_json::from_value(value)
179            .map_err(|e| Error::Transport(format!("invalid task payload: {e}")))?;
180        Ok(task)
181    }
182
183    /// Fetch the current state of a task.
184    pub fn get_task(&self, task_id: &str) -> Result<Task> {
185        let value = self.request("GET", &format!("/tasks/{}", task_id), None)?;
186        serde_json::from_value(value).map_err(|e| Error::Transport(format!("invalid task payload: {e}")))
187    }
188
189    /// Create a task and poll until it reaches a terminal state.
190    pub fn run(&self, input: CreateTaskInput, options: Option<RunOptions>) -> Result<Task> {
191        let opts = options.unwrap_or_default();
192        let mut task = self.create_task(&input)?;
193        let deadline = Instant::now() + opts.max_wait;
194        while !task.is_done() {
195            if Instant::now() > deadline {
196                return Err(Error::Api {
197                    code: None,
198                    status: None,
199                    message: format!(
200                        "task {} did not finish within {:?}",
201                        task.task_id, opts.max_wait
202                    ),
203                });
204            }
205            thread::sleep(opts.poll_interval);
206            task = self.get_task(&task.task_id)?;
207        }
208        if task.task_status == Task::STATUS_FAILED {
209            return Err(Error::Api {
210                code: Some(Task::STATUS_FAILED as i64),
211                status: None,
212                message: task
213                    .msg
214                    .clone()
215                    .unwrap_or_else(|| format!("task {} failed", task.task_id)),
216            });
217        }
218        Ok(task)
219    }
220
221    fn request(
222        &self,
223        method: &str,
224        path: &str,
225        body: Option<serde_json::Value>,
226    ) -> Result<serde_json::Value> {
227        let url = format!("{}{}", self.base_url, path);
228        let req = self
229            .agent
230            .request(method, &url)
231            .set("Authorization", &format!("Bearer {}", self.api_key))
232            .set("Accept", "application/json");
233        let response = match body {
234            Some(b) => req.send_json(b),
235            None => req.call(),
236        };
237        match response {
238            Ok(r) => {
239                let envelope: serde_json::Value = r
240                    .into_json()
241                    .map_err(|e| Error::Transport(format!("invalid JSON: {e}")))?;
242                if let Some(code) = envelope.get("code").and_then(|v| v.as_i64()) {
243                    if code != 200 {
244                        let message = envelope
245                            .get("msg")
246                            .and_then(|v| v.as_str())
247                            .unwrap_or("business error")
248                            .to_owned();
249                        return Err(Error::Api {
250                            code: Some(code),
251                            status: None,
252                            message,
253                        });
254                    }
255                }
256                // Unwrap the `data` envelope; fall back to top-level for tolerance.
257                if let Some(data) = envelope.get("data").cloned() {
258                    if !data.is_null() {
259                        return Ok(data);
260                    }
261                }
262                Ok(envelope)
263            }
264            Err(ureq::Error::Status(status, resp)) => {
265                if status == 401 {
266                    return Err(Error::Api {
267                        code: None,
268                        status: Some(401),
269                        message: "unauthorized — check your OMNIFLASH_API_KEY (https://omniflash.net/)".to_owned(),
270                    });
271                }
272                let body = resp.into_string().unwrap_or_default();
273                let msg = serde_json::from_str::<serde_json::Value>(&body)
274                    .ok()
275                    .and_then(|v| v.get("msg").and_then(|m| m.as_str().map(str::to_owned)))
276                    .unwrap_or_else(|| format!("HTTP {status}"));
277                Err(Error::Api {
278                    code: None,
279                    status: Some(status),
280                    message: msg,
281                })
282            }
283            Err(e) => Err(Error::Transport(e.to_string())),
284        }
285    }
286}