use crate::{
api_definitions::{GetCollectionModels, ListCollectionModels},
errors::ReplicateError,
};
/// Client handle for the `/collections` endpoints.
///
/// Every request issued through this handle takes its base URL, auth token
/// and user agent from the wrapped configuration.
#[derive(Debug, Clone)]
pub struct Collection {
    /// Client configuration shared with the owning API handle.
    pub parent: crate::config::Config,
}
impl Collection {
    /// Creates a collections client that sends requests using `rep`'s base
    /// URL, auth token and user agent.
    pub fn new(rep: crate::config::Config) -> Self {
        Self { parent: rep }
    }

    /// Fetches a single collection, including its models, by slug.
    ///
    /// Sends `GET {base_url}/collections/{collection_slug}`. The slug is
    /// percent-encoded, so characters such as `/`, `?` or `#` stay inside the
    /// path segment instead of altering the request target.
    ///
    /// # Errors
    ///
    /// - Returns [`ReplicateError::ResponseError`] with the response body
    ///   when the server answers with a non-2xx status.
    /// - Propagates HTTP transport errors and JSON decoding errors via `?`.
    pub fn get(&self, collection_slug: &str) -> Result<GetCollectionModels, ReplicateError> {
        let path = format!("collections/{}", encode_path_segment(collection_slug));
        let body = self.get_body(&path)?;
        Ok(serde_json::from_str(&body)?)
    }

    /// Lists the available collections via `GET {base_url}/collections`.
    ///
    /// # Errors
    ///
    /// - Returns [`ReplicateError::ResponseError`] with the response body
    ///   when the server answers with a non-2xx status.
    /// - Propagates HTTP transport errors and JSON decoding errors via `?`.
    pub fn list(&self) -> Result<ListCollectionModels, ReplicateError> {
        let body = self.get_body("collections")?;
        Ok(serde_json::from_str(&body)?)
    }

    /// Performs an authenticated `GET {base_url}/{path}` and returns the
    /// response body text.
    ///
    /// Shared by [`Collection::get`] and [`Collection::list`] so both send
    /// identical headers and handle failures identically.
    fn get_body(&self, path: &str) -> Result<String, ReplicateError> {
        let response = reqwest::blocking::Client::new()
            .get(format!("{}/{}", self.parent.base_url, path))
            .header("Authorization", format!("Token {}", self.parent.auth))
            .header("User-Agent", &self.parent.user_agent)
            .send()?;
        // A non-2xx status is reported with the server's body text as the
        // payload; the status code itself is not included.
        if !response.status().is_success() {
            return Err(ReplicateError::ResponseError(response.text()?));
        }
        Ok(response.text()?)
    }
}

/// Percent-encodes `segment` so it can be used as a single URL path segment.
///
/// Bytes in the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`) are kept
/// verbatim. Every other byte of the UTF-8 encoding is written as `%XX`
/// (uppercase hex).
fn encode_path_segment(segment: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(segment.len());
    for byte in segment.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};
    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// Builds a configuration that points at `server` and uses the fixed
    /// token `"test"`; all other fields come from `Config::default()`.
    fn test_config(server: &MockServer) -> Config {
        Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        }
    }

    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();
        // The mock also requires the token header, so a missing or malformed
        // Authorization header makes the request go unmatched.
        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/collections/super-resolution")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!({
                "name": "Super resolution",
                "slug": "super-resolution",
                "description": "Upscaling models that create high-quality images from low-quality images.",
                "models": [],
            }));
        });
        let replicate = Replicate::new(test_config(&server));
        let result = replicate.collections.get("super-resolution")?;
        assert_eq!(result.name, "Super resolution");
        get_mock.assert();
        Ok(())
    }

    #[test]
    fn test_list() -> Result<(), ReplicateError> {
        let server = MockServer::start();
        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/collections")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!({
                "results": [
                    {
                        "name": "Super resolution",
                        "slug": "super-resolution",
                        "description": "Upscaling models that create high-quality images from low-quality images.",
                    },
                    {
                        "name": "Image classification",
                        "slug": "image-classification",
                        "description": "Models that classify images.",
                    },
                ],
                "next": None::<String>,
                "previous": None::<String>,
            }));
        });
        let replicate = Replicate::new(test_config(&server));
        let result = replicate.collections.list()?;
        assert_eq!(result.results.len(), 2);
        get_mock.assert();
        Ok(())
    }

    #[test]
    fn test_get_error_status() {
        let server = MockServer::start();
        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/collections/missing");
            then.status(404).body("not found");
        });
        let replicate = Replicate::new(test_config(&server));
        // A non-2xx status must surface as ResponseError carrying the body text.
        match replicate.collections.get("missing") {
            Ok(_) => panic!("expected an error for a 404 response"),
            Err(ReplicateError::ResponseError(body)) => assert_eq!(body, "not found"),
            Err(other) => panic!("unexpected error variant: {:?}", other),
        }
        get_mock.assert();
    }
}