1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
use reqwest::header;
use serde::Deserialize;

use crate::Client;
impl Client {
    /// Runs the configured classification model on `string` and returns the
    /// classifications parsed from the response.
    ///
    /// A `POST` request is sent to
    /// `https://api-inference.huggingface.co/models/{classification_model}`
    /// with `self.config.key` as a bearer token, and the response body is
    /// deserialized as a JSON array of [`Classification`].
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the API key cannot be encoded as an HTTP header value,
    /// - the request cannot be sent or the response body cannot be read,
    /// - the server answers with a non-success HTTP status (the error message
    ///   carries the status and the response body),
    /// - the response body is not a JSON array of [`Classification`].
    pub async fn get_classifications(
        &self,
        string: String,
    ) -> Result<Vec<Classification>, Box<dyn std::error::Error>> {
        let mut authorization: header::HeaderValue =
            format!("Bearer {}", self.config.key).parse()?;
        // Keep the API key out of `Debug` output of the header map / request.
        authorization.set_sensitive(true);

        let mut headers = header::HeaderMap::new();
        headers.insert(header::AUTHORIZATION, authorization);
        headers.insert(
            header::CONTENT_TYPE,
            "application/x-www-form-urlencoded".parse()?,
        );

        // NOTE(review): a fresh `reqwest::Client` is built on every call; if
        // `Client` can hold one, reusing it would share the connection pool.
        // NOTE(review): `string` is placed in the body without percent-encoding
        // even though the content type is form-urlencoded — confirm how the
        // API interprets inputs containing `&`, `=` or `%`.
        let response = reqwest::Client::new()
            .post(format!(
                "https://api-inference.huggingface.co/models/{}",
                self.config.classification_model
            ))
            .headers(headers)
            .body(format!("inputs={}", string))
            .send()
            .await?;

        // Read the status before consuming the response so that a failed
        // request is reported as such instead of as a JSON parse error.
        let status = response.status();
        let body = response.text().await?;
        if !status.is_success() {
            return Err(format!(
                "classification request failed with status {}: {}",
                status, body
            )
            .into());
        }

        Ok(serde_json::from_str(&body)?)
    }
}
/// One element of the JSON array returned by the classification model.
///
/// Each field is deserialized from the JSON key of the same name.
#[derive(Deserialize, Debug)]
pub struct Classification {
    /// Value of the `entity_group` key (presumably the predicted entity label
    /// — confirm against the model's output).
    pub entity_group: String,
    /// Value of the `score` key (presumably the model's confidence — confirm).
    pub score: f32,
    /// Value of the `word` key (presumably the matched text — confirm).
    pub word: String,
    /// Value of the `start` key (presumably the offset in the input where the
    /// match begins — confirm unit and inclusivity against the API docs).
    pub start: usize,
    /// Value of the `end` key (presumably the offset in the input where the
    /// match ends — confirm unit and inclusivity against the API docs).
    pub end: usize,
}
#[cfg(test)]
mod tests {
    use crate::{Client, Config};

    /// Sends a real request through `get_classifications`.
    ///
    /// Requires the `HUGGINGFACE_API_KEY` environment variable; the test
    /// panics if it is not set.
    #[tokio::test]
    async fn classification() {
        let mut config = Config::default();
        config.key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY not set");
        let client = Client::new(config);

        let result = client
            .get_classifications("hello i am Yvonne Take i live in Amsterdam".to_string())
            .await;

        // Include the underlying error so a failing run shows why the request
        // or the deserialization went wrong.
        assert!(
            result.is_ok(),
            "classification request failed: {:?}",
            result.err()
        );
    }
}