// openai_rs/endpoints/classification.rs

use std::borrow::Cow;
use std::collections::HashMap;

use hyper::{Body, Request};
use serde::Serialize;

use crate::endpoints::request::Endpoint;
use crate::endpoints::Model;
7
8/// Given a query and a set of labeled examples, the model will predict the most likely label for the query.
9/// Useful as a drop-in replacement for any ML classification or text-to-label task.
10#[derive(Debug, Clone, Serialize)]
11pub struct Classification<'a> {
12    /// ID of the engine to use for completion. You can select one of ada, babbage, curie, or davinci.
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub model: Option<Model>,
15
16    /// Query to be classified.
17    pub query: Cow<'a, str>,
18
19    /// A list of examples with labels, in the following format:
20    /// `[["The movie is so interesting.", "Positive"], ["It is quite boring.", "Negative"], ...]`
21    /// All the label strings will be normalized to be capitalized.
22    /// You should specify either examples or file, but not both.
23    #[serde(skip_serializing_if = "Vec::is_empty")]
24    pub examples: Vec<[Cow<'a, str>; 2]>,
25
26    /// The ID of the uploaded file that contains training examples.
27    /// See upload file for how to upload a file of the desired format and purpose.
28    /// You should specify either examples or file, but not both.
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub file: Option<Cow<'a, str>>,
31
32    /// The set of categories being classified. If not specified, candidate labels will be
33    /// automatically collected from the examples you provide. All the label strings will be
34    /// normalized to be capitalized.
35    #[serde(skip_serializing_if = "Vec::is_empty")]
36    pub labels: Vec<Cow<'a, str>>,
37
38    /// ID of the engine to use for Search. You can select one of ada, babbage, curie, or davinci
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub search_model: Option<Model>,
41
42    /// What sampling temperature to use. Higher values mean the model will take more risks.
43    /// Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub temperature: Option<f32>,
46
47    /// Include the log probabilities on the logprobs most likely tokens, as well the chosen tokens.
48    /// For example, if logprobs is 5, the API will return a list of the 5 most likely tokens.
49    /// The API will always return the logprob of the sampled token,
50    /// so there may be up to logprobs+1 elements in the response.
51    /// The maximum value for logprobs is 5.
52    /// If you need more than this, please contact support@openai.com and describe your use case.
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub logprobs: Option<u32>,
55
56    /// The maximum number of examples to be ranked by Search when using file.
57    /// Setting it to a higher value leads to improved accuracy but with increased latency and cost.
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub max_examples: Option<u32>,
60
61    /// Modify the likelihood of specified tokens appearing in the completion.
62    /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer)
63    /// to an associated bias value from -100 to 100. You can use this tokenizer tool (which works
64    /// for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically,
65    /// the bias is added to the logits generated by the model prior to sampling. The exact effect
66    /// will vary per model, but values between -1 and 1 should decrease or increase likelihood
67    /// of selection; values like -100 or 100 should result in a ban or exclusive selection
68    /// of the relevant token.
69    #[serde(skip_serializing_if = "HashMap::is_empty")]
70    pub logit_bias: HashMap<Cow<'a, str>, i32>,
71
72    /// If set to true, the returned JSON will include a "prompt" field containing the final prompt
73    /// that was used to request a completion. This is mainly useful for debugging purposes.
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub return_prompt: Option<bool>,
76
77    /// A special boolean flag for showing metadata. If set to true, each document entry in the
78    /// returned JSON will contain a "metadata" field. This flag only takes effect when file is set.
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub return_metadata: Option<bool>,
81
82    /// If set to true, the returned JSON will include a "prompt" field containing the final prompt
83    /// that was used to request a completion. This is mainly useful for debugging purposes.
84    #[serde(skip_serializing_if = "Vec::is_empty")]
85    pub expand: Vec<Cow<'a, str>>,
86
87    /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub user: Option<Cow<'a, str>>,
90}
91
92impl Default for Classification<'_> {
93    fn default() -> Self {
94        Self {
95            model: None,
96            query: Cow::Borrowed(""),
97            examples: vec![],
98            file: None,
99            labels: vec![],
100            search_model: None,
101
102            temperature: None,
103
104            logprobs: None,
105            max_examples: None,
106
107            logit_bias: HashMap::new(),
108
109            return_prompt: None,
110            return_metadata: None,
111            expand: vec![],
112            user: None
113        }
114    }
115}
116
117impl Endpoint for Classification<'_> {
118    const ENDPOINT: &'static str = "https://api.openai.com/v1/classifications";
119
120    fn request(&self, auth_token: &str, _engine_id: Option<&str>) -> Request<Body> {
121        let serialized = serde_json::to_string(self)
122            .expect("Failed to serialize Classification");
123        let endpoint = Self::ENDPOINT.to_owned();
124        trace!("endpoint={}, serialized={}", endpoint, serialized);
125
126        super::request::post!(endpoint, auth_token, serialized)
127    }
128}