use std::collections::HashMap;
use crate::{
api_definitions::{CreatePrediction, GetPrediction, PredictionStatus, PredictionsUrls},
errors::ReplicateError,
prediction::PredictionPayload,
};
use super::retry::{RetryPolicy, RetryStrategy};
/// Splits a Replicate version string of the form `owner/model:version`.
///
/// Returns `Some((model, version))`, where `model` is the `owner/model` part and
/// `version` is everything after the first `:`.
///
/// Returns `None` in any of these cases:
/// * there is no `:`;
/// * the model part is not `owner/name` with both halves non-empty;
/// * the version part is empty.
///
/// Such strings can never name a real model version, so they are rejected
/// here rather than being sent to the API.
pub fn parse_version(s: &str) -> Option<(&str, &str)> {
    let (model, version) = s.split_once(':')?;
    let (owner, name) = model.split_once('/')?;
    if owner.is_empty() || name.is_empty() || version.is_empty() {
        return None;
    }
    Some((model, version))
}
/// A single prediction on the Replicate API, as last seen by this client.
///
/// Built by [`PredictionClient::create`] from the API's creation response.
/// Every field except `parent` is overwritten by [`PredictionClient::reload`].
#[derive(Clone, Debug)]
pub struct PredictionClient {
    /// Client configuration whose base URL, auth token and user agent are used
    /// for every request made on behalf of this prediction.
    pub parent: crate::config::Config,
    /// Identifier of the prediction; used to build its `/predictions/{id}` URLs.
    pub id: String,
    /// Model version identifier, as returned by the API.
    pub version: String,
    /// URLs associated with this prediction, as returned by the API.
    pub urls: PredictionsUrls,
    /// Creation timestamp, kept as the raw string returned by the API.
    pub created_at: String,
    /// Last known status of the prediction.
    pub status: PredictionStatus,
    /// Input values, as returned by the API.
    pub input: HashMap<String, serde_json::Value>,
    /// Error message returned by the API, if any.
    pub error: Option<String>,
    /// Logs returned by the API, if any.
    pub logs: Option<String>,
}
impl PredictionClient {
    /// Creates a prediction on the Replicate API and returns a client tracking it.
    ///
    /// `version` must look like `owner/model:version` (see [`parse_version`]).
    /// Only the part after the first `:` is sent to the API as the version.
    ///
    /// # Errors
    ///
    /// * [`ReplicateError::InvalidVersionString`] if `version` cannot be parsed.
    /// * [`ReplicateError::ResponseError`] carrying the response body if the API
    ///   answers with a non-success HTTP status.
    /// * Request-sending or response-decoding errors, converted through `?`.
    pub fn create<K: serde::Serialize, V: serde::ser::Serialize>(
        rep: crate::config::Config,
        version: &str,
        inputs: HashMap<K, V>,
    ) -> Result<PredictionClient, ReplicateError> {
        let (_model, version_id) = parse_version(version)
            .ok_or_else(|| ReplicateError::InvalidVersionString(version.to_string()))?;
        let payload = PredictionPayload {
            version: version_id.to_string(),
            input: inputs,
        };
        let client = reqwest::blocking::Client::new();
        let response = client
            .post(format!("{}/predictions", rep.base_url))
            .header("Authorization", format!("Token {}", rep.auth))
            .header("User-Agent", &rep.user_agent)
            .json(&payload)
            .send()?;
        let result: CreatePrediction = Self::ensure_success(response)?.json()?;
        Ok(Self {
            parent: rep,
            id: result.id,
            version: result.version,
            urls: result.urls,
            created_at: result.created_at,
            status: result.status,
            input: result.input,
            error: result.error,
            logs: result.logs,
        })
    }

    /// Re-fetches this prediction and overwrites every field except `parent`
    /// with the state returned by the API.
    ///
    /// # Errors
    ///
    /// * [`ReplicateError::ResponseError`] carrying the response body on a
    ///   non-success HTTP status.
    /// * Request-sending or JSON-decoding errors, converted through `?`.
    pub fn reload(&mut self) -> Result<(), ReplicateError> {
        let client = reqwest::blocking::Client::new();
        let prediction = self.fetch_prediction(&client)?;
        self.id = prediction.id;
        self.version = prediction.version;
        self.urls = prediction.urls;
        self.created_at = prediction.created_at;
        self.status = prediction.status;
        self.input = prediction.input;
        self.error = prediction.error;
        self.logs = prediction.logs;
        Ok(())
    }

    /// Asks the API to cancel this prediction (`POST /predictions/{id}/cancel`),
    /// then calls [`PredictionClient::reload`] to refresh the local state.
    ///
    /// # Errors
    ///
    /// * [`ReplicateError::ResponseError`] carrying the response body if the
    ///   cancel request gets a non-success HTTP status.
    /// * Any error returned by the follow-up `reload`.
    pub fn cancel(&mut self) -> Result<(), ReplicateError> {
        let client = reqwest::blocking::Client::new();
        let response = client
            .post(format!(
                "{}/predictions/{}/cancel",
                self.parent.base_url, self.id
            ))
            .header("Authorization", format!("Token {}", self.parent.auth))
            .header("User-Agent", &self.parent.user_agent)
            .send()?;
        Self::ensure_success(response)?;
        self.reload()
    }

    /// Polls the API until this prediction reaches a terminal status
    /// (`succeeded`, `failed` or `canceled`), then returns that final state.
    ///
    /// `self` is not modified. Call [`PredictionClient::reload`] to update it.
    ///
    /// After each non-terminal poll (`starting` or `processing`), this calls
    /// `step()` on a `RetryPolicy::new(5, RetryStrategy::FixedDelay(1000))`.
    /// That presumably waits 1000 ms between polls.
    ///
    /// NOTE(review): the loop only exits on a terminal status or an error.
    /// Whether `RetryPolicy::step` enforces the limit of 5 is defined in
    /// `super::retry`; confirm there, otherwise this can wait forever.
    ///
    /// # Errors
    ///
    /// * [`ReplicateError::ResponseError`] carrying the response body on a
    ///   non-success HTTP status.
    /// * Request-sending or JSON-decoding errors, converted through `?`.
    pub fn wait(&self) -> Result<GetPrediction, ReplicateError> {
        let retry_policy = RetryPolicy::new(5, RetryStrategy::FixedDelay(1000));
        let client = reqwest::blocking::Client::new();
        loop {
            let prediction = self.fetch_prediction(&client)?;
            match prediction.status {
                PredictionStatus::succeeded
                | PredictionStatus::failed
                | PredictionStatus::canceled => return Ok(prediction),
                PredictionStatus::processing | PredictionStatus::starting => {
                    retry_policy.step();
                }
            }
        }
    }

    /// Fetches the current state of this prediction (`GET /predictions/{id}`).
    ///
    /// Shared by `reload` and `wait` so both issue the exact same request.
    fn fetch_prediction(
        &self,
        client: &reqwest::blocking::Client,
    ) -> Result<GetPrediction, ReplicateError> {
        let response = client
            .get(format!("{}/predictions/{}", self.parent.base_url, self.id))
            .header("Authorization", format!("Token {}", self.parent.auth))
            .header("User-Agent", &self.parent.user_agent)
            .send()?;
        let body = Self::ensure_success(response)?.text()?;
        let prediction: GetPrediction = serde_json::from_str(&body)?;
        Ok(prediction)
    }

    /// Passes `response` through unchanged if its HTTP status is a success.
    ///
    /// Otherwise it consumes the response and returns its body text as
    /// [`ReplicateError::ResponseError`].
    fn ensure_success(
        response: reqwest::blocking::Response,
    ) -> Result<reqwest::blocking::Response, ReplicateError> {
        if response.status().is_success() {
            Ok(response)
        } else {
            Err(ReplicateError::ResponseError(response.text()?))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{config::Config, Replicate};
    use httpmock::{Method::POST, MockServer};
    use serde_json::json;

    /// Creating a prediction POSTs to `/predictions` and builds the client
    /// from the mocked reply.
    #[test]
    fn test_create() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        // Canned API reply for a freshly created prediction.
        let reply = json!({
            "id": "ufawqhfynnddngldkgtslldrkq",
            "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
            "urls": {
                "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
                "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
            },
            "created_at": "2022-04-26T22:13:06.224088Z",
            "started_at": null,
            "completed_at": null,
            "status": "starting",
            "input": { "text": "Alice" },
            "output": null,
            "error": null,
            "logs": null,
            "metrics": {}
        });
        let create_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&reply);
        });

        let config = Config {
            auth: "test".to_string(),
            base_url: server.base_url(),
            ..Config::default()
        };
        let replicate = Replicate::new(config);

        let mut inputs = HashMap::new();
        inputs.insert("text", "Alice");
        let prediction = replicate.predictions.create(
            "owner/model:632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
            inputs,
        )?;

        assert_eq!(prediction.id, "ufawqhfynnddngldkgtslldrkq");
        create_mock.assert();
        Ok(())
    }
}