//! # Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
//!
//! The model module contains all the functionality for interacting with the model endpoints of the Replicate API.
//! Currently supports the following endpoints:
//! * [Get Model](https://replicate.com/docs/reference/http#models.get)
//! * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
//! * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
//! * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
//!
//! # Example
//! ```
//! use replicate_rust::{Replicate, config::Config};
//!
//! let config = Config::default();
//! let replicate = Replicate::new(config);
//!
//! let model = replicate.models.get("replicate", "hello-world")?;
//! println!("Model : {:?}", model);
//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
//! ```
use crate::{api_definitions::GetModel, errors::ReplicateError, version::Version};
/// Client handle for the [Model Endpoints](https://replicate.com/docs/reference/http#models.get) of the Replicate API.
///
/// Bundles the configuration used for requests with a [`Version`] handle that
/// `Model::new` builds from a clone of that same configuration.
#[derive(Clone, Debug)]
pub struct Model {
/// Owned configuration (not a borrow). `get` reads its `base_url`, `auth` and `user_agent` fields when building a request.
pub parent: crate::config::Config,
/// Owned [`Version`] handle, used for the version endpoints of the Replicate API.
pub versions: Version,
}
/// Operations on the model endpoints of the Replicate API.
///
/// Endpoints covered by this part of the client:
/// * [Get Model](https://replicate.com/docs/reference/http#models.get)
/// * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
/// * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
/// * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
///
/// Only `get` is implemented directly in this block; the `versions` field
/// holds the handle used for the version endpoints.
impl Model {
    /// Builds a `Model` handle from a configuration.
    ///
    /// # Arguments
    /// * `rep` - The config (`crate::config::Config`) to use for authentication and communication.
    ///
    /// The config is cloned once so that the nested `Version` handle owns its
    /// own copy; the original value is stored in `parent`.
    pub fn new(rep: crate::config::Config) -> Self {
        Self {
            // Struct fields are evaluated in written order, so the clone is
            // taken before `rep` itself is moved into `parent`.
            versions: Version::new(rep.clone()),
            parent: rep,
        }
    }

    /// Fetches the details of a single model.
    ///
    /// Sends a blocking `GET` to `{base_url}/models/{model_owner}/{model_name}`
    /// with `Authorization: Token <auth>` and `User-Agent` headers taken from
    /// the stored config, then decodes the JSON body into a `GetModel`.
    ///
    /// # Arguments
    /// * `model_owner` - The owner of the model.
    /// * `model_name` - The name of the model.
    ///
    /// # Errors
    /// * `ReplicateError::ResponseError` carrying the response body when the
    ///   HTTP status is not a success.
    /// * Failures while sending the request, reading the body, or decoding
    ///   the JSON are propagated through `?`.
    ///
    /// # Example
    /// ```
    /// use replicate_rust::{Replicate, config::Config};
    ///
    /// let config = Config::default();
    /// let replicate = Replicate::new(config);
    ///
    /// let model = replicate.models.get("replicate", "hello-world")?;
    /// println!("Model : {:?}", model);
    ///
    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
    /// ```
    pub fn get(&self, model_owner: &str, model_name: &str) -> Result<GetModel, ReplicateError> {
        let endpoint = format!(
            "{}/models/{}/{}",
            self.parent.base_url, model_owner, model_name
        );
        let auth_header = format!("Token {}", self.parent.auth);

        let http = reqwest::blocking::Client::new();
        let reply = http
            .get(endpoint)
            .header("Authorization", auth_header)
            .header("User-Agent", &self.parent.user_agent)
            .send()?;

        // The body is needed on both paths (error payload or JSON document),
        // so record the status first and read the text exactly once.
        let succeeded = reply.status().is_success();
        let body = reply.text()?;

        if succeeded {
            Ok(serde_json::from_str::<GetModel>(&body)?)
        } else {
            Err(ReplicateError::ResponseError(body))
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};
    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// `models.get` should request `/models/{owner}/{name}` on the configured
    /// base URL and decode a 200 JSON response without error.
    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let mock_server = MockServer::start();

        // Canned model document returned by the mock endpoint.
        let payload = json!({
            "url": "https://replicate.com/replicate/hello-world",
            "owner": "replicate",
            "name": "hello-world",
            "description": "A tiny model that says hello",
            "visibility": "public",
            "github_url": "https://github.com/replicate/cog-examples",
            "paper_url": null,
            "license_url": null,
            "run_count": 12345,
            "cover_image_url": "",
            "default_example": {},
            "latest_version": {}
        });

        let model_endpoint = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/hello-world");
            then.status(200).json_body_obj(&payload);
        });

        // Point the client at the mock server instead of the real API.
        let replicate = Replicate::new(Config {
            auth: String::from("test"),
            base_url: mock_server.base_url(),
            ..Config::default()
        });

        let fetched = replicate.models.get("replicate", "hello-world")?;
        println!("{:?}", fetched);

        // Ensure the mocked endpoint was actually called.
        model_endpoint.assert();

        Ok(())
    }
}