1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use reqwest::{header, Response};

use super::error::FalError;

/// Credentials used by [`FalClient`] to build the `Authorization: Key ...` request header.
pub enum ClientCredentials {
    /// A single API key, sent verbatim after the `Key ` prefix.
    Key(String),
    /// NOTE(review): `ClientCredentials::from_env` in this file never produces this variant,
    /// and `FalClient` panics when asked to authenticate with it. The payload presumably
    /// names an environment variable — TODO confirm the intent or remove the variant.
    FromEnv(String),
    /// A key id and a secret, joined as `key_id:secret` when building the header.
    KeyPair(String, String),
}

/// Environment variables checked, in order, for a single API key.
const ENV_CANDIDATES: [&str; 2] = ["FAL_KEY", "FAL_API_KEY"];

impl ClientCredentials {
    /// Resolves credentials from the process environment.
    ///
    /// A `.env` file is loaded first (load errors are ignored). Lookup order:
    /// 1. the first non-empty variable among `FAL_KEY` and `FAL_API_KEY` yields
    ///    [`ClientCredentials::Key`];
    /// 2. otherwise, if both `FAL_KEY_ID` and `FAL_SECRET` are non-empty, yields
    ///    [`ClientCredentials::KeyPair`].
    ///
    /// Variables that are set but empty are treated as unset, so e.g. a blank `FAL_KEY=`
    /// line does not shadow a valid `FAL_API_KEY`.
    ///
    /// # Panics
    ///
    /// Panics if no usable credentials are found. Use [`ClientCredentials::try_from_env`]
    /// to handle that case without panicking.
    pub fn from_env() -> Self {
        Self::try_from_env().unwrap_or_else(|| {
            panic!("FAL_KEY or FAL_KEY_ID and FAL_SECRET must be set in the environment")
        })
    }

    /// Non-panicking counterpart of [`ClientCredentials::from_env`].
    ///
    /// Returns `None` when neither a single key nor a complete key pair is available.
    pub fn try_from_env() -> Option<Self> {
        // A missing `.env` file is not an error: real environment variables may suffice.
        dotenvy::dotenv().ok();

        if let Some(key) = ENV_CANDIDATES
            .iter()
            .copied()
            .find_map(Self::non_empty_env)
        {
            return Some(Self::Key(key));
        }

        let key_id = Self::non_empty_env("FAL_KEY_ID")?;
        let secret = Self::non_empty_env("FAL_SECRET")?;
        Some(Self::KeyPair(key_id, secret))
    }

    /// Wraps a single API key.
    pub fn from_key(key: &str) -> Self {
        Self::Key(key.to_string())
    }

    /// Wraps a key id / secret pair.
    pub fn from_key_pair(key_id: &str, secret: &str) -> Self {
        Self::KeyPair(key_id.to_string(), secret.to_string())
    }

    /// Reads `name` from the environment, treating unset, non-Unicode, and empty values
    /// alike as absent.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }
}

/// Client for invoking fal functions over HTTPS.
///
/// Only the credentials are stored; an HTTP client carrying the `Authorization`
/// header is constructed from them for each call to `run`.
pub struct FalClient {
    /// Credentials used to authenticate every request.
    credentials: ClientCredentials,
}

impl FalClient {
    /// Builds the absolute URL for `path` on the fal run host.
    ///
    /// The host is read from the `FAL_RUN_HOST` environment variable on every call and
    /// defaults to `fal.run`. Leading slashes on `path` are stripped, so `"/foo"` and
    /// `"foo"` produce the same URL.
    pub fn build_url(&self, path: &str) -> String {
        let host = std::env::var("FAL_RUN_HOST").unwrap_or_else(|_| "fal.run".to_string());
        format!("https://{}/{}", host, path.trim_start_matches('/'))
    }

    /// Creates a client that authenticates every request with `credentials`.
    pub fn new(credentials: ClientCredentials) -> Self {
        Self { credentials }
    }

    /// Builds an HTTP client whose default headers carry `Authorization: Key <creds>`.
    ///
    /// # Errors
    ///
    /// - [`FalError::InvalidCredentials`] if the credentials contain bytes that are not
    ///   valid in an HTTP header value (e.g. a trailing newline picked up from a file).
    /// - [`FalError::RequestError`] if the underlying HTTP client cannot be constructed.
    ///
    /// # Panics
    ///
    /// Panics for [`ClientCredentials::FromEnv`], which this client does not resolve.
    fn client(&self) -> Result<reqwest::Client, FalError> {
        let creds = match &self.credentials {
            ClientCredentials::Key(key) => key.clone(),
            ClientCredentials::KeyPair(key_id, secret) => format!("{}:{}", key_id, secret),
            ClientCredentials::FromEnv(_) => panic!("FAL_API_KEY must be set in the environment"),
        };

        let mut auth = header::HeaderValue::from_str(&format!("Key {}", creds))
            .map_err(|_| FalError::InvalidCredentials)?;
        // Keep the secret out of `Debug` output of headers/requests.
        auth.set_sensitive(true);

        let mut headers = header::HeaderMap::new();
        headers.insert(header::AUTHORIZATION, auth);

        // NOTE(review): a fresh client (and connection pool) is built per request; consider
        // caching one on `FalClient` if request volume matters.
        reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .map_err(FalError::RequestError)
    }

    /// POSTs `inputs` as JSON to the endpoint `function_id` and returns the raw response.
    ///
    /// # Errors
    ///
    /// - [`FalError::InvalidCredentials`] if the credentials cannot be encoded as a header,
    ///   or the server answers `401 Unauthorized` / `403 Forbidden`.
    /// - [`FalError::RequestError`] if the request cannot be sent, or the server answers
    ///   with any other non-success status (the status is available on the inner error).
    pub async fn run<T: serde::Serialize>(
        &self,
        function_id: &str,
        inputs: T,
    ) -> Result<Response, FalError> {
        let client = self.client()?;
        let url = self.build_url(function_id);
        let res = client
            .post(&url)
            .json(&inputs)
            .send()
            .await
            .map_err(FalError::RequestError)?;

        // Only authentication failures are credential errors; validation (4xx) and
        // server (5xx) failures are surfaced with their status via reqwest's error.
        let status = res.status();
        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN
        {
            return Err(FalError::InvalidCredentials);
        }

        res.error_for_status().map_err(FalError::RequestError)
    }
}