use anyhow::anyhow;
use futures_lite::io::AsyncReadExt;
use isahc::{prelude::*, Request};
use serde::Deserialize;
use serde_json::Value;
use crate::config::ReplicateConfig;
/// Error body decoded from a non-successful API response.
///
/// Only the `detail` message is read; it becomes the text of the error
/// returned to the caller.
#[derive(Debug, Deserialize)]
struct ModelVersionError {
    /// Human-readable explanation supplied by the server.
    detail: String,
}
#[derive(Debug, Deserialize, Clone)]
pub struct ModelVersion {
pub id: String,
pub created_at: String,
pub cog_version: String,
pub openapi_schema: serde_json::Value,
}
/// One page of a model's versions.
#[derive(Debug, Deserialize)]
pub struct ModelVersions {
    /// Pagination reference for the following page, if the server sent one.
    pub next: Option<String>,
    /// Pagination reference for the preceding page, if the server sent one.
    pub previous: Option<String>,
    /// The versions contained in this page, in the order received.
    pub results: Vec<ModelVersion>,
}
/// One page of models.
#[derive(Debug, Deserialize)]
pub struct Models {
    /// Pagination reference for the following page, if the server sent one.
    pub next: Option<String>,
    /// Pagination reference for the preceding page, if the server sent one.
    pub previous: Option<String>,
    /// The models contained in this page, in the order received.
    pub results: Vec<Model>,
}
/// A model description, as decoded from the API's JSON.
///
/// NOTE(review): every field that is not an `Option` (or an untyped `Value`)
/// must be present and non-null in the payload, otherwise decoding fails.
/// Confirm against the API schema that `description`, `github_url`,
/// `cover_image_url` and `latest_version` can never be `null`.
#[derive(Debug, Deserialize)]
pub struct Model {
    /// URL of the model's page.
    pub url: String,
    /// Account that owns the model.
    pub owner: String,
    /// Name of the model.
    pub name: String,
    /// Free-text description of the model.
    pub description: String,
    /// Visibility label as sent by the server (the tests use `"public"` and `"private"`).
    pub visibility: String,
    /// URL of the model's source repository.
    pub github_url: String,
    /// URL of an associated paper, when provided.
    pub paper_url: Option<String>,
    /// URL of the model's license, when provided.
    pub license_url: Option<String>,
    /// Run counter reported by the server.
    pub run_count: usize,
    /// URL of the model's cover image.
    pub cover_image_url: String,
    /// Example attached to the model, stored as untyped JSON (may be `null`).
    pub default_example: Value,
    /// The version the server reports as the model's latest.
    pub latest_version: ModelVersion,
}
/// Client for the model-related endpoints of the API.
pub struct ModelClient {
    /// Configuration that supplies the API key and base URL for every request.
    client: ReplicateConfig,
}
impl ModelClient {
pub fn from(client: ReplicateConfig) -> Self {
ModelClient { client }
}
pub async fn get(&self, owner: &str, name: &str) -> anyhow::Result<Model> {
let api_key = self.client.get_api_key()?;
let base_url = self.client.get_base_url();
let endpoint = format!("{base_url}/models/{owner}/{name}");
let response = Request::get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.await?;
let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;
let model: Model = serde_json::from_slice(&bytes)?;
anyhow::Ok(model)
}
pub async fn get_specific_version(
&self,
owner: &str,
name: &str,
version_id: &str,
) -> anyhow::Result<Model> {
let api_key = self.client.get_api_key()?;
let base_url = self.client.get_base_url();
let endpoint = format!("{base_url}/models/{owner}/{name}/versions/{version_id}");
let response = Request::get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.await?;
let mut bytes = Vec::new();
response.into_body().read_to_end(&mut bytes).await?;
let model: Model = serde_json::from_slice(&bytes)?;
anyhow::Ok(model)
}
pub async fn get_latest_version(
&self,
owner: &str,
name: &str,
) -> anyhow::Result<ModelVersion> {
let all_versions = self.list_versions(owner, name).await?;
let latest_version = all_versions
.results
.get(0)
.ok_or(anyhow!("no versions found for {owner}/{name}"))?;
anyhow::Ok(latest_version.clone())
}
pub async fn list_versions(&self, owner: &str, name: &str) -> anyhow::Result<ModelVersions> {
let base_url = self.client.get_base_url();
let api_key = self.client.get_api_key()?;
let endpoint = format!("{base_url}/models/{owner}/{name}/versions");
let mut response = Request::get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.await?;
let mut bytes = Vec::new();
response.body_mut().read_to_end(&mut bytes).await?;
if response.status().is_success() {
let data: ModelVersions = serde_json::from_slice(&bytes)?;
anyhow::Ok(data)
} else {
let data: ModelVersionError = serde_json::from_slice(&bytes)?;
Err(anyhow!(data.detail))
}
}
pub async fn get_models(&self) -> anyhow::Result<Models> {
let base_url = self.client.get_base_url();
let api_key = self.client.get_api_key()?;
let endpoint = format!("{base_url}/models");
let mut response = Request::get(endpoint)
.header("Authorization", format!("Token {api_key}"))
.body({})?
.send_async()
.await?;
let mut bytes = Vec::new();
response.body_mut().read_to_end(&mut bytes).await?;
let models: Models = serde_json::from_slice(&bytes)?;
anyhow::Ok(models)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use serde_json::json;

    /// Version id used by the `replicate/hello-world` fixtures.
    const HELLO_WORLD_VERSION: &str =
        "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa";

    /// Builds a `ModelClient` whose requests are routed to `server`.
    fn client_for(server: &MockServer) -> ModelClient {
        let config = ReplicateConfig::test(server.base_url()).unwrap();
        ModelClient::from(config)
    }

    /// JSON body describing `replicate/hello-world` with `version_id` as its latest version.
    fn hello_world_model(version_id: &str) -> serde_json::Value {
        json!({
            "url": "https://replicate.com/replicate/hello-world",
            "owner": "replicate",
            "name": "hello-world",
            "description": "A tiny model that says hello",
            "visibility": "public",
            "github_url": "https://github.com/replicate/cog-examples",
            "paper_url": null,
            "license_url": null,
            "run_count": 5681081,
            "cover_image_url": "...",
            "default_example": null,
            "latest_version": {
                "id": version_id,
                "created_at": "2022-04-26T19:29:04.418669Z",
                "cog_version": "0.3.0",
                "openapi_schema": {}
            }
        })
    }

    /// JSON body for a single-entry page of `replicate/hello-world` versions.
    fn hello_world_versions() -> serde_json::Value {
        json!({
            "next": null,
            "previous": null,
            "results": [{
                "id": HELLO_WORLD_VERSION,
                "created_at": "2022-04-26T19:29:04.418669Z",
                "cog_version": "0.3.0",
                "openapi_schema": null
            }]
        })
    }

    #[tokio::test]
    async fn test_get_model() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/hello-world");
            then.status(200)
                .json_body_obj(&hello_world_model(HELLO_WORLD_VERSION));
        });

        let model = client_for(&mock_server)
            .get("replicate", "hello-world")
            .await
            .unwrap();

        model_mock.assert();
        assert_eq!(model.owner, "replicate");
        assert_eq!(model.name, "hello-world");
        assert_eq!(model.run_count, 5681081);
        assert!(model.paper_url.is_none());
        assert_eq!(model.latest_version.id, HELLO_WORLD_VERSION);
    }

    #[tokio::test]
    async fn test_get_specific_version() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions/1234");
            then.status(200).json_body_obj(&hello_world_model("1234"));
        });

        let model = client_for(&mock_server)
            .get_specific_version("replicate", "hello-world", "1234")
            .await
            .unwrap();

        model_mock.assert();
        assert_eq!(model.name, "hello-world");
        assert_eq!(model.latest_version.id, "1234");
    }

    #[tokio::test]
    async fn test_list_model_versions() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&hello_world_versions());
        });

        let page = client_for(&mock_server)
            .list_versions("replicate", "hello-world")
            .await
            .unwrap();

        model_mock.assert();
        assert!(page.next.is_none());
        assert!(page.previous.is_none());
        assert_eq!(page.results.len(), 1);
        assert_eq!(page.results[0].id, HELLO_WORLD_VERSION);
    }

    #[tokio::test]
    async fn test_list_model_versions_reports_api_error() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(404)
                .json_body_obj(&json!({ "detail": "Not found." }));
        });

        let error = client_for(&mock_server)
            .list_versions("replicate", "hello-world")
            .await
            .unwrap_err();

        model_mock.assert();
        // The server's `detail` text is surfaced verbatim.
        assert_eq!(error.to_string(), "Not found.");
    }

    #[tokio::test]
    async fn test_get_latest_version() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&hello_world_versions());
        });

        let version = client_for(&mock_server)
            .get_latest_version("replicate", "hello-world")
            .await
            .unwrap();

        model_mock.assert();
        assert_eq!(version.id, HELLO_WORLD_VERSION);
        assert_eq!(version.cog_version, "0.3.0");
    }

    #[tokio::test]
    async fn test_get_latest_version_without_versions() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": []
            }));
        });

        let error = client_for(&mock_server)
            .get_latest_version("replicate", "hello-world")
            .await
            .unwrap_err();

        model_mock.assert();
        assert_eq!(
            error.to_string(),
            "no versions found for replicate/hello-world"
        );
    }

    #[tokio::test]
    async fn test_get_models() {
        let mock_server = MockServer::start();
        let model_mock = mock_server.mock(|when, then| {
            when.method(GET).path("/models");
            then.status(200).json_body_obj(&json!({
                "next": "some pagination string or null",
                "previous": "some pagination string or null",
                "results": [
                    {
                        "url": "https://modelhomepage.example.com",
                        "owner": "jdoe",
                        "name": "super-cool-model",
                        "description": "A model that predicts something very cool.",
                        "visibility": "public",
                        "github_url": "https://github.com/jdoe/super-cool-model",
                        "paper_url": "https://research.example.com/super-cool-model-paper.pdf",
                        "license_url": null,
                        "run_count": 420,
                        "cover_image_url": "https://cdn.example.com/images/super-cool-model-cover.jpg",
                        "default_example": {
                            "input": "Example input data for the model."
                        },
                        "latest_version": {
                            "id": "v1.0.0",
                            "created_at": "2022-01-01T12:00:00Z",
                            "cog_version": "0.2",
                            "openapi_schema": null
                        }
                    },
                    {
                        "url": "https://anothermodelhomepage.example.com",
                        "owner": "asmith",
                        "name": "another-awesome-model",
                        "description": "This model does awesome things with data.",
                        "visibility": "private",
                        "github_url": "https://github.com/asmith/another-awesome-model",
                        "paper_url": null,
                        "license_url": "https://licenses.example.com/another-awesome-model-license.txt",
                        "run_count": 150,
                        "cover_image_url": "https://cdn.example.com/images/another-awesome-model-cover.jpg",
                        "default_example": {
                            "input": "Some example input for this awesome model."
                        },
                        "latest_version": {
                            "id": "v1.2.3",
                            "created_at": "2023-02-15T08:30:00Z",
                            "cog_version": "0.2",
                            "openapi_schema": null
                        }
                    }
                ]
            }));
        });

        let models = client_for(&mock_server).get_models().await.unwrap();

        model_mock.assert();
        assert_eq!(
            models.next.as_deref(),
            Some("some pagination string or null")
        );
        assert_eq!(models.results.len(), 2);
        assert_eq!(models.results[0].owner, "jdoe");
        assert_eq!(models.results[0].run_count, 420);
        assert_eq!(models.results[1].visibility, "private");
        assert_eq!(models.results[1].latest_version.id, "v1.2.3");
    }
}