use serde::{Deserialize, Serialize};
use std::env;
use std::thread;
use std::time::{Duration, Instant};
/// Base URL that [`OmniFlash::new`] installs on every freshly built client.
///
/// Endpoint paths such as `/tasks/create` are appended directly to this value;
/// use [`OmniFlash::with_base_url`] to point a client somewhere else.
pub const DEFAULT_BASE_URL: &str = "https://omniflash.net/api/v1";
/// A task record as decoded from (and re-encodable to) a JSON payload.
///
/// Only `task_id` is mandatory when deserializing: every other field carries a
/// serde default, so a payload that omits it still decodes.
///
/// NOTE(review): the derived `Default` leaves `task_status` at `0`, whereas a
/// payload missing the field decodes to `default_status()`; the two "defaults"
/// therefore disagree — confirm whether that is intended.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Task {
/// Task identifier; `OmniFlash::get_task` interpolates it into the request path.
pub task_id: String,
/// Numeric status, compared against the `Task::STATUS_*` constants.
/// Falls back to `default_status()` when the payload omits it.
#[serde(default = "default_status")]
pub task_status: u8,
/// Second choice of `Task::output_url`. Presumably the generated image — TODO confirm.
#[serde(default)]
pub image_url: Option<String>,
/// First choice of `Task::output_url`. Presumably the generated video — TODO confirm.
#[serde(default)]
pub video_url: Option<String>,
/// Last choice of `Task::output_url`. Presumably the generated audio — TODO confirm.
#[serde(default)]
pub audio_url: Option<String>,
/// Not read anywhere in this file; presumably a server-side request correlation id — TODO confirm.
#[serde(default)]
pub request_id: Option<String>,
/// Not read anywhere in this file; presumably the kind of generation requested — TODO confirm.
#[serde(default)]
pub task_type: Option<String>,
/// Not read anywhere in this file; presumably echoes `CreateTaskInput::model_id` — TODO confirm.
#[serde(default)]
pub model_id: Option<String>,
/// Not read anywhere in this file; presumably the credits charged for the task — TODO confirm.
#[serde(default)]
pub credits: Option<i64>,
/// Not read anywhere in this file; presumably a creation timestamp (unit unknown) — TODO confirm.
#[serde(default)]
pub created_at: Option<i64>,
/// Optional message; `OmniFlash::run` uses it as the error text when the task failed.
#[serde(default)]
pub msg: Option<String>,
}
/// Serde fallback for `Task::task_status`: a payload that omits the field is
/// decoded as `Task::STATUS_QUEUED`.
fn default_status() -> u8 {
Task::STATUS_QUEUED
}
impl Task {
    /// Status value for a task that has been accepted but not started.
    pub const STATUS_QUEUED: u8 = 1;
    /// Status value for a task that is currently being processed.
    pub const STATUS_RUNNING: u8 = 2;
    /// Terminal status value: the task completed successfully.
    pub const STATUS_SUCCESS: u8 = 3;
    /// Terminal status value: the task failed.
    pub const STATUS_FAILED: u8 = 4;

    /// Returns `true` once the task has reached a terminal status
    /// (`STATUS_SUCCESS` or `STATUS_FAILED`).
    pub fn is_done(&self) -> bool {
        let terminal = [Self::STATUS_SUCCESS, Self::STATUS_FAILED];
        terminal.contains(&self.task_status)
    }

    /// Returns the first populated output URL, checking `video_url`, then
    /// `image_url`, then `audio_url`; `None` when all three are absent.
    pub fn output_url(&self) -> Option<&str> {
        // Order encodes priority: the first slot holding a value wins.
        [&self.video_url, &self.image_url, &self.audio_url]
            .iter()
            .find_map(|slot| slot.as_deref())
    }
}
/// Request body for `OmniFlash::create_task`, serialized to JSON before sending.
///
/// `model_id` and `prompt` are always emitted; the optional fields are left out
/// of the JSON entirely when they are `None`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct CreateTaskInput {
/// Identifier of the model to run — accepted values are defined by the API, not by this file.
pub model_id: String,
/// Prompt text forwarded to the API as-is.
pub prompt: String,
/// Optional list of image URLs; presumably reference/input images — TODO confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub image_urls: Option<Vec<String>>,
/// Optional aspect ratio string; the expected format is not shown here — TODO confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub aspect_ratio: Option<String>,
}
/// Polling parameters for `OmniFlash::run`.
#[derive(Clone, Debug)]
pub struct RunOptions {
    /// Pause between two consecutive status polls.
    pub poll_interval: Duration,
    /// Total time budget for the task to reach a terminal status.
    pub max_wait: Duration,
}

impl Default for RunOptions {
    /// Polls every 3 seconds and gives up after 10 minutes.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(3);
        let max_wait = Duration::from_secs(10 * 60);
        Self {
            poll_interval,
            max_wait,
        }
    }
}
/// Everything that can go wrong while talking to the API.
#[derive(Debug)]
pub enum Error {
    /// No API key was supplied and none could be found in the environment.
    MissingApiKey,
    /// The request could not be completed or its payload could not be decoded.
    Transport(String),
    /// The server (or the task itself) reported a failure.
    Api {
        /// Application-level code, when one is known.
        code: Option<i64>,
        /// HTTP status code, when the failure came from a non-success response.
        status: Option<u16>,
        /// Human-readable description of the failure.
        message: String,
    },
}

impl std::fmt::Display for Error {
    /// Renders a one-line description; for `Api` the `status=` and `code=`
    /// parts appear only when the corresponding value is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two simple variants return immediately; only `Api` needs assembly.
        let (code, status, message) = match self {
            Error::MissingApiKey => {
                return f.write_str(
                    "missing API key — pass via builder or set OMNIFLASH_API_KEY. Get one at https://omniflash.net/",
                );
            }
            Error::Transport(detail) => return write!(f, "transport error: {detail}"),
            Error::Api {
                code,
                status,
                message,
            } => (code, status, message),
        };

        f.write_str("api error")?;
        if let Some(http_status) = status {
            write!(f, " status={http_status}")?;
        }
        if let Some(api_code) = code {
            write!(f, " code={api_code}")?;
        }
        write!(f, ": {message}")
    }
}

impl std::error::Error for Error {}

/// Result alias used throughout this module, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Blocking client for the OmniFlash HTTP API.
///
/// Construct it with `OmniFlash::new`, optionally re-target it with
/// `OmniFlash::with_base_url`, then call `create_task`, `get_task` or `run`.
pub struct OmniFlash {
/// Sent as `Authorization: Bearer <api_key>` on every request.
api_key: String,
/// Prefix that endpoint paths are appended to; `with_base_url` stores it without trailing `/`.
base_url: String,
/// `ureq` agent built in `new` and used for every request.
agent: ureq::Agent,
}
impl OmniFlash {
pub fn new(api_key: Option<&str>) -> Result<Self> {
let key = api_key
.map(str::to_owned)
.or_else(|| env::var("OMNIFLASH_API_KEY").ok())
.filter(|s| !s.is_empty())
.ok_or(Error::MissingApiKey)?;
Ok(Self {
api_key: key,
base_url: DEFAULT_BASE_URL.to_owned(),
agent: ureq::AgentBuilder::new()
.timeout(Duration::from_secs(60))
.build(),
})
}
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into().trim_end_matches('/').to_owned();
self
}
pub fn create_task(&self, input: &CreateTaskInput) -> Result<Task> {
let value = self.request("POST", "/tasks/create", Some(serde_json::to_value(input).unwrap()))?;
let task: Task = serde_json::from_value(value)
.map_err(|e| Error::Transport(format!("invalid task payload: {e}")))?;
Ok(task)
}
pub fn get_task(&self, task_id: &str) -> Result<Task> {
let value = self.request("GET", &format!("/tasks/{}", task_id), None)?;
serde_json::from_value(value).map_err(|e| Error::Transport(format!("invalid task payload: {e}")))
}
pub fn run(&self, input: CreateTaskInput, options: Option<RunOptions>) -> Result<Task> {
let opts = options.unwrap_or_default();
let mut task = self.create_task(&input)?;
let deadline = Instant::now() + opts.max_wait;
while !task.is_done() {
if Instant::now() > deadline {
return Err(Error::Api {
code: None,
status: None,
message: format!(
"task {} did not finish within {:?}",
task.task_id, opts.max_wait
),
});
}
thread::sleep(opts.poll_interval);
task = self.get_task(&task.task_id)?;
}
if task.task_status == Task::STATUS_FAILED {
return Err(Error::Api {
code: Some(Task::STATUS_FAILED as i64),
status: None,
message: task
.msg
.clone()
.unwrap_or_else(|| format!("task {} failed", task.task_id)),
});
}
Ok(task)
}
fn request(
&self,
method: &str,
path: &str,
body: Option<serde_json::Value>,
) -> Result<serde_json::Value> {
let url = format!("{}{}", self.base_url, path);
let req = self
.agent
.request(method, &url)
.set("Authorization", &format!("Bearer {}", self.api_key))
.set("Accept", "application/json");
let response = match body {
Some(b) => req.send_json(b),
None => req.call(),
};
match response {
Ok(r) => {
let envelope: serde_json::Value = r
.into_json()
.map_err(|e| Error::Transport(format!("invalid JSON: {e}")))?;
if let Some(code) = envelope.get("code").and_then(|v| v.as_i64()) {
if code != 200 {
let message = envelope
.get("msg")
.and_then(|v| v.as_str())
.unwrap_or("business error")
.to_owned();
return Err(Error::Api {
code: Some(code),
status: None,
message,
});
}
}
if let Some(data) = envelope.get("data").cloned() {
if !data.is_null() {
return Ok(data);
}
}
Ok(envelope)
}
Err(ureq::Error::Status(status, resp)) => {
if status == 401 {
return Err(Error::Api {
code: None,
status: Some(401),
message: "unauthorized — check your OMNIFLASH_API_KEY (https://omniflash.net/)".to_owned(),
});
}
let body = resp.into_string().unwrap_or_default();
let msg = serde_json::from_str::<serde_json::Value>(&body)
.ok()
.and_then(|v| v.get("msg").and_then(|m| m.as_str().map(str::to_owned)))
.unwrap_or_else(|| format!("HTTP {status}"));
Err(Error::Api {
code: None,
status: Some(status),
message: msg,
})
}
Err(e) => Err(Error::Transport(e.to_string())),
}
}
}