use std::marker::PhantomData;
use serde::{de::DeserializeOwned, Serialize};
use crate::{
queue::{Queue, QueueResponse},
FalError,
};
#[derive(Debug)]
pub struct FalRequest<Params: Serialize, Response: DeserializeOwned> {
pub client: reqwest::Client,
pub endpoint: String,
pub params: Params,
pub api_key: Option<String>,
phantom: PhantomData<Response>,
}
impl<Params: Serialize, Response: DeserializeOwned> FalRequest<Params, Response> {
    /// Creates a request for `endpoint` with the given `params`.
    ///
    /// Uses a fresh [`reqwest::Client`] and takes the API key from the
    /// `FAL_API_KEY` environment variable. If that variable cannot be read,
    /// the key is left as `None`.
    pub fn new(endpoint: impl Into<String>, params: Params) -> Self {
        Self {
            client: reqwest::Client::new(),
            endpoint: endpoint.into(),
            params,
            api_key: std::env::var("FAL_API_KEY").ok(),
            phantom: PhantomData,
        }
    }

    /// Replaces the HTTP client used to send the request.
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.client = client;
        self
    }

    /// Sets the API key explicitly, overriding any value read from `FAL_API_KEY`.
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// POSTs the request to `https://fal.run/{endpoint}` and deserializes the
    /// JSON response body into `Response`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - no API key is configured;
    /// - the HTTP request fails;
    /// - the server answers with a status other than 200 (the error carries the
    ///   response body text);
    /// - the body cannot be deserialized into `Response`.
    pub async fn send(self) -> Result<Response, FalError> {
        let key = self.require_api_key()?;
        let request = self.authorized_post("https://fal.run", key);
        let response = Self::ensure_ok(request.send().await?).await?;
        Ok(response.json().await?)
    }

    /// POSTs the request to `https://queue.fal.run/{endpoint}` and returns a
    /// [`Queue`] handle built from the submission response.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - no API key is configured;
    /// - the HTTP request fails;
    /// - the server answers with a status other than 200 (the error carries the
    ///   response body text);
    /// - the body cannot be deserialized into a `QueueResponse`.
    pub async fn queue(self) -> Result<Queue<Response>, FalError> {
        // Owned copy: the key is handed to the returned `Queue`.
        let key = self.require_api_key()?.to_owned();
        let request = self.authorized_post("https://queue.fal.run", &key);
        let response = Self::ensure_ok(request.send().await?).await?;
        let payload: QueueResponse = response.json().await?;
        Ok(Queue::new(self.client, self.endpoint, key, payload))
    }

    /// Returns the configured API key, or an error (instead of a panic) when
    /// none was provided.
    fn require_api_key(&self) -> Result<&str, FalError> {
        self.api_key.as_deref().ok_or_else(|| {
            String::from("No fal API key provided, and FAL_API_KEY environment variable is not set")
                .into()
        })
    }

    /// Builds an authenticated JSON POST to `{base_url}/{endpoint}`.
    fn authorized_post(&self, base_url: &str, key: &str) -> reqwest::RequestBuilder {
        // `json()` already sets `Content-Type: application/json`. `header()`
        // appends rather than replaces, so setting it again would duplicate the header.
        self.client
            .post(format!("{}/{}", base_url, self.endpoint))
            .json(&self.params)
            .header("Authorization", format!("Key {}", key))
    }

    /// Returns `response` unchanged when its status is 200.
    ///
    /// Any other status, including other 2xx codes, is turned into an error
    /// whose message is the response body text.
    async fn ensure_ok(response: reqwest::Response) -> Result<reqwest::Response, FalError> {
        if response.status() != 200 {
            let error = response.text().await?;
            return Err(error.into());
        }
        Ok(response)
    }
}