1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
//! # Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
//!
//! The model module contains all the functionality for interacting with the model endpoints of the Replicate API.
//! Currently supports the following endpoints (the three version endpoints are reached through
//! the `versions` field of `Model`):
//! * [Get Model](https://replicate.com/docs/reference/http#models.get)
//! * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
//! * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
//! * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
//!
//! # Example
//! ```
//! use replicate_rust::{Replicate, config::Config};
//!
//! let config = Config::default();
//! let replicate = Replicate::new(config);
//!
//! let model = replicate.models.get("replicate", "hello-world")?;
//! println!("Model : {:?}", model);
//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
//! ```

use crate::{api_definitions::GetModel, errors::ReplicateError, version::Version};

/// Used to interact with the [Model Endpoints](https://replicate.com/docs/reference/http#models.get).
///
/// [`Model::new`] builds both fields from a single configuration.
#[derive(Clone, Debug)]
pub struct Model {
    /// Owned copy of the configuration (base URL, auth token, user agent, among other settings)
    /// used when sending requests to the model endpoints.
    pub parent: crate::config::Config,

    /// Client for the version endpoints of the Replicate API
    /// (get, list and delete model versions).
    pub versions: Version,
}

/// Functionality for interacting with the model endpoints of the Replicate API.
///
/// Endpoint implemented directly on this type:
/// * [Get Model](https://replicate.com/docs/reference/http#models.get) via [`Model::get`]
///
/// The version endpoints are reached through the [`Model::versions`] field:
/// * [Get Model Versions](https://replicate.com/docs/reference/http#models.versions.get)
/// * [List Model Versions](https://replicate.com/docs/reference/http#models.versions.list)
/// * [Delete Model Version](https://replicate.com/docs/reference/http#models.versions.delete)
impl Model {
    /// Create a new Model struct.
    ///
    /// A clone of `rep` is handed to the nested [`Version`] client, so both use the same
    /// settings.
    ///
    /// # Arguments
    /// * `rep` - The config (`crate::config::Config`) to use for authentication and communication.
    pub fn new(rep: crate::config::Config) -> Self {
        let versions = Version::new(rep.clone());
        Self {
            parent: rep,
            versions,
        }
    }

    /// Get the details of a model.
    ///
    /// Sends `GET {base_url}/models/{model_owner}/{model_name}` with the configured auth token
    /// and user agent, then deserializes the JSON body into a [`GetModel`].
    ///
    /// # Arguments
    /// * `model_owner` - The owner of the model.
    /// * `model_name` - The name of the model.
    ///
    /// # Errors
    /// * The request cannot be sent, or the response body cannot be read (HTTP client error).
    /// * [`ReplicateError::ResponseError`] carrying the raw response body when the API replies
    ///   with a non-success status code.
    /// * A deserialization error when the body is not valid JSON for [`GetModel`].
    ///
    /// # Example
    /// ```
    /// use replicate_rust::{Replicate, config::Config};
    ///
    /// let config = Config::default();
    /// let replicate = Replicate::new(config);
    ///
    /// let model = replicate.models.get("replicate", "hello-world")?;
    /// println!("Model : {:?}", model);
    ///
    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
    /// ```
    pub fn get(&self, model_owner: &str, model_name: &str) -> Result<GetModel, ReplicateError> {
        let endpoint = format!(
            "{}/models/{}/{}",
            self.parent.base_url, model_owner, model_name
        );

        // NOTE(review): a new client (and connection pool) is created on every call; storing one
        // on `Model` would allow connection reuse if this is called frequently.
        let client = reqwest::blocking::Client::new();

        let response = client
            .get(endpoint)
            .header("Authorization", format!("Token {}", self.parent.auth))
            .header("User-Agent", &self.parent.user_agent)
            .send()?;

        // `text()` consumes the response, so record the status first. The body is then read once
        // and used either as the error payload or as the JSON to parse.
        let succeeded = response.status().is_success();
        let body = response.text()?;

        if !succeeded {
            return Err(ReplicateError::ResponseError(body));
        }

        let model: GetModel = serde_json::from_str(&body)?;
        Ok(model)
    }
}

#[cfg(test)]
mod tests {
    use crate::{config::Config, errors::ReplicateError, Replicate};

    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// `models.get` should call `/models/{owner}/{name}` on the configured base URL
    /// and successfully deserialize the JSON body served there.
    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let mock_server = MockServer::start();

        // Canned "get model" response body served by the mock endpoint.
        let payload = json!({
            "url": "https://replicate.com/replicate/hello-world",
            "owner": "replicate",
            "name": "hello-world",
            "description": "A tiny model that says hello",
            "visibility": "public",
            "github_url": "https://github.com/replicate/cog-examples",
            "paper_url": null,
            "license_url": null,
            "run_count": 12345,
            "cover_image_url": "",
            "default_example": {},
            "latest_version": {}
        });

        let model_endpoint = mock_server.mock(|when, then| {
            when.method(GET).path("/models/replicate/hello-world");
            then.status(200).json_body_obj(&payload);
        });

        // Point the client at the mock server instead of the real API.
        let replicate = Replicate::new(Config {
            auth: String::from("test"),
            base_url: mock_server.base_url(),
            ..Config::default()
        });

        let model = replicate.models.get("replicate", "hello-world")?;
        println!("{:?}", model);

        // Verify the mocked endpoint was actually hit.
        model_endpoint.assert();

        Ok(())
    }
}