cohere_rust/api/classify.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use super::{EmbedModel, Truncate};
6
7#[derive(Serialize, Default, Debug)]
8pub struct ClassifyRequest<'input> {
9    /// An optional string representing the model you'd like to use.
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub model: Option<EmbedModel>,
12    /// An optional string representing the ID of a custom playground preset.
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub preset: Option<String>,
15    /// An array of strings that you would like to classify.
16    pub inputs: &'input [String],
17    /// An array of ClassifyExamples representing examples and the corresponding label.
18    pub examples: &'input [ClassifyExample<'input>],
19    /// Specify how the API will handle inputs longer than the maximum token length.
20    pub truncate: Option<Truncate>,
21}
22
23#[derive(Serialize, Debug)]
24pub struct ClassifyExample<'input> {
25    /// The text of the example.
26    pub text: &'input str,
27    /// The label that fits the example's text.
28    pub label: &'input str,
29}
30
31#[derive(Deserialize, Debug)]
32pub struct Confidence {
33    /// The label.
34    pub label: String,
35    /// The associated confidence with the label.
36    pub confidence: f64,
37}
38
39#[derive(Deserialize, Debug, PartialEq)]
40pub struct LabelProperties {
41    pub confidence: f64,
42}
43
44#[derive(Deserialize, Debug, PartialEq)]
45pub struct Classification {
46    pub id: String,
47    /// The top predicted label for the text.
48    pub prediction: String,
49    /// Confidence score for the top predicted label.
50    pub confidence: f32,
51    /// Confidence score for each label.
52    pub labels: HashMap<String, LabelProperties>,
53    /// The text that is being classified.
54    pub input: String,
55}
56
57#[derive(Deserialize, Debug)]
58pub(crate) struct ClassifyResponse {
59    pub classifications: Vec<Classification>,
60}