use crate::{api_definitions::GetModel, errors::ReplicateError, version::Version};
/// Client handle for the `/models` endpoints.
///
/// Holds the configuration used to build requests and a [`Version`] handle
/// that was constructed from a clone of that same configuration.
#[derive(Clone, Debug)]
pub struct Model {
    /// Configuration this handle was created with; `get` reads its
    /// `base_url`, `auth` and `user_agent` fields when building requests.
    pub parent: crate::config::Config,
    /// Sub-client for model versions, built from a clone of `parent`.
    pub versions: Version,
}
impl Model {
    /// Builds a [`Model`] handle from `rep`.
    ///
    /// The `versions` sub-client receives its own clone of `rep`; the original
    /// value is then stored as `parent`.
    pub fn new(rep: crate::config::Config) -> Self {
        Self {
            // Struct-literal fields are evaluated in source order, so the clone
            // happens before `rep` is moved into `parent`.
            versions: Version::new(rep.clone()),
            parent: rep,
        }
    }

    /// Fetches one model via `GET {base_url}/models/{model_owner}/{model_name}`.
    ///
    /// The request carries an `Authorization: Token <auth>` header and the
    /// configured `User-Agent` header.
    ///
    /// # Errors
    ///
    /// * Failures from `reqwest` while sending the request or reading the body
    ///   are propagated through `?`.
    /// * A non-2xx status yields [`ReplicateError::ResponseError`] carrying the
    ///   raw response body text.
    /// * A body that does not deserialize into [`GetModel`] yields the
    ///   `serde_json` error, converted through `?`.
    pub fn get(&self, model_owner: &str, model_name: &str) -> Result<GetModel, ReplicateError> {
        let endpoint = format!(
            "{}/models/{}/{}",
            self.parent.base_url, model_owner, model_name
        );
        let token = format!("Token {}", self.parent.auth);

        let resp = reqwest::blocking::Client::new()
            .get(endpoint)
            .header("Authorization", token)
            .header("User-Agent", &self.parent.user_agent)
            .send()?;

        if resp.status().is_success() {
            let body = resp.text()?;
            Ok(serde_json::from_str::<GetModel>(&body)?)
        } else {
            Err(ReplicateError::ResponseError(resp.text()?))
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};
    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// `models.get` should request `/models/{owner}/{name}` with the configured
    /// token and deserialize the mocked payload without error.
    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();
        let get_mock = server.mock(|when, then| {
            // Match the Authorization header too, so the test fails if the
            // client stops sending the token from `Config::auth` (set to
            // "test" below). An unmatched request gets a 404 from the mock
            // server, which `get` surfaces as an error.
            when.method(GET)
                .path("/models/replicate/hello-world")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!({
                "url": "https://replicate.com/replicate/hello-world",
                "owner": "replicate",
                "name": "hello-world",
                "description": "A tiny model that says hello",
                "visibility": "public",
                "github_url": "https://github.com/replicate/cog-examples",
                "paper_url": None::<String>,
                "license_url": None::<String>,
                "run_count": 12345,
                "cover_image_url": "",
                "default_example": {},
                "latest_version": {}
            }));
        });

        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        let replicate = Replicate::new(config);

        // Success here means the request matched and the body deserialized.
        replicate.models.get("replicate", "hello-world")?;

        get_mock.assert();
        Ok(())
    }
}