1use std::time::SystemTime;
2use reqwest::StatusCode;
3use serde_json::Value;
4use uuid::Uuid;
5use lazy_static::lazy_static;
6
7use crate::types::BananaError;
8use crate::types::BananaResponse;
9use crate::types::Payload;
10use crate::types::CheckPayload;
11
12lazy_static! {
14 static ref BANANA_URL: String = {
15 match std::env::var("BANANA_URL") {
16 Ok(v) => {
17 println!("dev mode");
18 if v == "local".to_string() {
19 return "http://localhost/".to_string();
20 } else {
21 return v;
22 }
23 },
24 Err(_) => {
25 return "https://api.banana.dev/".to_string();
26 }
27 };
28 };
29}
30
31pub async fn run_main(api_key: &str, model_key: &str, model_inputs: Value) -> Result<BananaResponse, BananaError> {
32 match start_api(api_key, model_key, model_inputs).await {
33 Ok(res) => {
34 match res.finished {
35 Some(value) => {
36 if value == true {
37 return Ok(res)
38 } else {
39 match res.call_i_d {
40 Some(value) => {
41 loop {
42 println!("polling...");
43 match check_api(api_key, &value).await {
44 Ok(res) => {
45 if res.message.to_ascii_lowercase() == "success" {
46 return Ok(res)
47 }
48 }
49 Err(e) => return Err(e)
50 }
51 }
52
53 },
54 None => return Err(BananaError::ResponseError("call id returned undefined".to_string()))
55 }
56 }
57 },
58 None => return Err(BananaError::ResponseError("finished returned undefined.".to_string()))
59 }
60 },
61 Err(e) => Err(e)
62 }
63}
64
65pub async fn start_main(api_key: &str, model_key: &str, model_inputs: Value) -> Result<String, BananaError> {
66 match start_api(api_key, model_key, model_inputs).await {
67 Ok(res) => {
68 match res.call_i_d {
69 Some(value) => return Ok(value),
70 None => return Err(BananaError::ResponseError("call id returned undefined.".to_string()))
71 }
72 }
73 Err(e) => Err(e)
74 }
75}
76
77pub async fn check_main(api_key: &str, call_id: &String) -> Result<BananaResponse, BananaError> {
78 check_api(api_key, call_id).await
79}
80
81async fn start_api(api_key: &str, model_key: &str, model_inputs: Value) -> Result<BananaResponse, BananaError> {
86
87 let mut url_start = BANANA_URL.clone();
89 let route_start = "start/v4/";
90 url_start.push_str(route_start);
91
92 let created = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
93 Ok(time) => time.as_millis() as usize,
94 Err(e) => return Err(BananaError::TimeError(e))
95 };
96
97 let payload = Payload {
98 id: Uuid::new_v4().to_string(),
99 created,
100 model_key: model_key.to_string(),
101 api_key: api_key.to_string(),
102 model_inputs,
103 start_only: false
104 };
105
106 let client = reqwest::Client::new();
107
108 match client.post(&url_start).json(&payload).send().await {
109 Ok(res) => {
110
111 let status = res.status();
112
113 if status != StatusCode::OK {
114 Err(BananaError::ServerError(status.to_string()))
115 } else {
116 match res.json::<BananaResponse>().await {
117 Ok(res) => {
118 if res.message.to_ascii_lowercase().contains("error") {
119 return Err(BananaError::ModelError(res.message))
120 } else {
121 Ok(res)
122 }
123 },
124 Err(e) => Err(BananaError::JsonError(e))
125
126 }
127 }
128 },
129 Err(e) => Err(BananaError::ConnectionError(e))
130 }
131}
132
133async fn check_api(api_key: &str, call_id: &String) -> Result<BananaResponse, BananaError> {
134
135 let mut url_check = BANANA_URL.clone();
136 println!("Hitting endpoint: {}", url_check);
137 let route_start = "check/v4/";
138 url_check.push_str(route_start);
139
140 let created = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
141 Ok(time) => time.as_millis() as usize,
142 Err(e) => return Err(BananaError::TimeError(e))
143 };
144
145 let payload = CheckPayload {
146 id: Uuid::new_v4().to_string(),
147 created,
148 long_poll: true,
149 call_i_d: call_id.to_string(),
150 api_key: api_key.to_string()
151 };
152
153 let client = reqwest::Client::new();
154
155 match client.post(url_check).json(&payload).send().await {
156 Ok(res) => {
157
158 let status = res.status();
159
160 if status != StatusCode::OK {
161 Err(BananaError::ServerError(status.to_string()))
162 } else {
163 let json = res.json::<BananaResponse>().await;
164 match json {
165 Ok(res) => {
166 if res.message.to_ascii_lowercase().contains("error") {
167 return Err(BananaError::ModelError(res.message))
168 } else {
169 Ok(res)
170 }
171 },
172 Err(e) => Err(BananaError::JsonError(e))
173 }
174 }
175 },
176 Err(e) => Err(BananaError::ConnectionError(e))
177 }
178}