pinecone_sdk/pinecone/
inference.rs

1use crate::openapi::apis::inference_api;
2use crate::openapi::models::{EmbedRequest, EmbedRequestInputsInner};
3use crate::pinecone::PineconeClient;
4use crate::utils::errors::PineconeError;
5
6use crate::models::{EmbedRequestParameters, EmbeddingsList};
7
8impl PineconeClient {
9    /// Generate embeddings for input data.
10    ///
11    /// ### Arguments
12    /// * `model: &str` - The model to use for embedding.
13    /// * `parameters: Option<EmbedRequestParameters>` - Model-specific parameters.
14    /// * `inputs: &Vec<&str>` - The input data to embed.
15    ///
16    /// ### Return
17    /// * `Result<EmbeddingsList, PineconeError>`
18    ///
19    /// ### Example
20    /// ```no_run
21    ///
22    /// # #[tokio::main]
23    /// # async fn main() -> Result<(), pinecone_sdk::utils::errors::PineconeError> {
24    ///
25    /// let pinecone = pinecone_sdk::pinecone::default_client()?;
26    /// let response = pinecone.embed("multilingual-e5-large", None, &vec!["Hello, world!"]).await.expect("Failed to embed");
27    ///
28    /// # Ok(())
29    /// # }
30    /// ```
31    pub async fn embed(
32        &self,
33        model: &str,
34        parameters: Option<EmbedRequestParameters>,
35        inputs: &Vec<&str>,
36    ) -> Result<EmbeddingsList, PineconeError> {
37        let request = EmbedRequest {
38            model: model.to_string(),
39            parameters: parameters.map(|x| Box::new(x)),
40            inputs: inputs
41                .iter()
42                .map(|&x| EmbedRequestInputsInner {
43                    text: Some(x.to_string()),
44                })
45                .collect(),
46        };
47
48        let res = inference_api::embed(&self.openapi_config, Some(request))
49            .await
50            .map_err(|e| PineconeError::from(e))?;
51
52        Ok(res.into())
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pinecone::PineconeClientConfig;
    use httpmock::prelude::*;
    use tokio;

    /// Creates a client whose control plane host is the mock server's base URL.
    fn client_for(server: &MockServer) -> PineconeClient {
        PineconeClientConfig {
            control_plane_host: Some(server.base_url()),
            ..Default::default()
        }
        .client()
        .expect("Failed to create Pinecone instance")
    }

    /// A 200 response from `POST /embed` is returned as a populated embeddings list.
    #[tokio::test]
    async fn test_embed() -> Result<(), PineconeError> {
        let server = MockServer::start();

        let embed_mock = server.mock(|when, then| {
            when.method(POST).path("/embed");
            then.status(200)
                .header("content-type", "application/json")
                .body(
                    r#"
                    {
                        "model": "multilingual-e5-large",
                        "data": [
                          {"values": [0.01849365234375, -0.003767013549804688, -0.037261962890625, 0.0222930908203125]}
                        ],
                        "usage": {"total_tokens": 1632}
                    }
                    "#,
                );
        });

        let client = client_for(&server);
        let inputs = vec!["Hello, world!"];

        let embeddings = client
            .embed("multilingual-e5-large", None, &inputs)
            .await
            .expect("Failed to embed");

        // The endpoint must have been hit exactly as mocked.
        embed_mock.assert();

        assert_eq!(embeddings.model, "multilingual-e5-large");
        assert_eq!(embeddings.data.len(), 1);
        assert_eq!(embeddings.usage.total_tokens, 1632);

        Ok(())
    }

    /// A 400 response from `POST /embed` is reported to the caller as an error.
    #[tokio::test]
    async fn test_embed_invalid_arguments() -> Result<(), PineconeError> {
        let server = MockServer::start();

        let embed_mock = server.mock(|when, then| {
            when.method(POST).path("/embed");
            then.status(400)
                .header("content-type", "application/json")
                .body(
                    r#"
                    {
                        "error": {
                          "code": "INVALID_ARGUMENT",
                          "message": "Invalid parameter value input_type='bad-parameter' for model 'multilingual-e5-large', must be one of [query, passage]"
                        },
                        "status": 400
                      }
                    "#,
                );
        });

        let client = client_for(&server);

        let bad_value = "bad-parameter".to_string();
        let invalid_parameters = EmbedRequestParameters {
            input_type: Some(bad_value.clone()),
            truncate: Some(bad_value),
        };
        let inputs = vec!["Hello, world!"];

        let _ = client
            .embed("multilingual-e5-large", Some(invalid_parameters), &inputs)
            .await
            .expect_err("Expected to fail embedding with invalid arguments");

        // The request must still have reached the mocked endpoint.
        embed_mock.assert();

        Ok(())
    }
}