rig_tei/
predict.rs

1use rig::http_client::{self, HttpClientExt};
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4
5use super::client::Client;
6
7#[derive(Debug, Deserialize, Serialize, Clone)]
8pub struct LabelScore {
9    pub label: String,
10    pub score: f32,
11}
12
13#[derive(Debug, Deserialize, Serialize, Clone)]
14pub struct PredictResponse {
15    pub items: Vec<LabelScore>,
16}
17
18#[derive(Debug, Deserialize)]
19struct ItemsShape {
20    items: Vec<LabelScore>,
21}
22#[derive(Debug, Deserialize)]
23struct PredictionsShape {
24    predictions: Vec<LabelScore>,
25}
26#[derive(Debug, Deserialize)]
27struct ArraysShape {
28    labels: Vec<String>,
29    scores: Vec<f32>,
30}
31
32#[derive(Debug, Deserialize)]
33#[serde(untagged)]
34enum PredictResponseInternal {
35    Items(ItemsShape),
36    Predictions(PredictionsShape),
37    Arrays(ArraysShape),
38}
39
40#[derive(thiserror::Error, Debug)]
41pub enum PredictError {
42    #[error("http error: {0}")]
43    Http(#[from] http_client::Error),
44    #[error("provider error: {0}")]
45    Provider(String),
46    #[error("response error: {0}")]
47    Response(String),
48}
49
50impl Client<reqwest::Client> {
51    /// Predict/classify inputs using TEI router endpoint (customizable via ClientBuilder)
52    pub async fn predict(
53        &self,
54        inputs: impl IntoIterator<Item = String>,
55    ) -> Result<PredictResponse, PredictError> {
56        let inputs_vec: Vec<String> = inputs.into_iter().collect();
57        let body_value = if inputs_vec.len() == 1 {
58            json!({ "inputs": inputs_vec[0] })
59        } else {
60            json!({ "inputs": inputs_vec })
61        };
62
63        let body =
64            serde_json::to_vec(&body_value).map_err(|e| PredictError::Response(e.to_string()))?;
65
66        let req = self
67            .post_full(&self.endpoints.predict)
68            .header("Content-Type", "application/json")
69            .body(body)
70            .map_err(|e| PredictError::Http(e.into()))?;
71
72        let response = HttpClientExt::send(&self.http_client, req).await?;
73        if !response.status().is_success() {
74            let text = http_client::text(response).await?;
75            return Err(PredictError::Provider(text));
76        }
77
78        let bytes: Vec<u8> = response.into_body().await?;
79        let internal: PredictResponseInternal = serde_json::from_slice(&bytes).map_err(|e| {
80            PredictError::Response(format!("Failed to parse TEI predict response: {e}"))
81        })?;
82
83        let items = match internal {
84            PredictResponseInternal::Items(x) => x.items,
85            PredictResponseInternal::Predictions(x) => x.predictions,
86            PredictResponseInternal::Arrays(x) => {
87                if x.labels.len() != x.scores.len() {
88                    return Err(PredictError::Response(
89                        "labels and scores length mismatch".into(),
90                    ));
91                }
92                x.labels
93                    .into_iter()
94                    .zip(x.scores.into_iter())
95                    .map(|(label, score)| LabelScore { label, score })
96                    .collect()
97            }
98        };
99
100        Ok(PredictResponse { items })
101    }
102}