elikoga_textsynth/completions/
logprob.rs

//! Provides the logprob API: the log probability of a continuation given a context.
2
3use serde::{Deserialize, Serialize};
4use serde_with::skip_serializing_none;
5use thiserror::Error;
6
7use crate::TextSynthClient;
8
9use super::Engine;
10
11/// Struct for a logprob request
12#[skip_serializing_none]
13#[derive(Serialize, Builder)]
14#[builder(setter(into))]
15#[builder(build_fn(validate = "Self::validate"))]
16pub struct Request {
17    /// If empty string, the context is set to the End-Of-Text token.
18    context: String,
19    /// Must be a non empty string.
20    continuation: String,
21}
22
23impl RequestBuilder {
24    fn validate(&self) -> Result<(), String> {
25        // n must be between 1 and 16
26        match &self.continuation {
27            Some(continuation) if continuation.is_empty() => {
28                return Err("Continuation must not be empty".to_string());
29            }
30            _ => {}
31        }
32        Ok(())
33    }
34}
35
36/// Struct for a logprob answer
37#[derive(Deserialize, Debug)]
38pub struct Response {
39    /// Logarithm of the probability of generation of continuation preceeded by
40    /// context. It corresponds to the sum of the logarithms of the
41    /// probabilities of the tokens of continuation. It is always <= 0.
42    pub logprob: f64,
43    /// Number of tokens in continuation.
44    pub num_tokens: u32,
45    /// true if continuation would be generated by greedy sampling from
46    /// continuation.
47    pub is_greedy: bool,
48    /// Indicate the total number of input tokens. It is useful to estimate the
49    /// number of compute resources used by the request.
50    pub input_tokens: u32,
51}
52
53#[derive(Error, Debug)]
54/// Error for a completion answer
55pub enum Error {
56    /// Serde error
57    #[error("Serde error: {0}")]
58    SerdeError(#[from] serde_json::Error),
59    /// Error from Reqwest
60    #[error("Reqwest error: {0}")]
61    RequestError(#[from] reqwest::Error),
62}
63
64impl TextSynthClient {
65    /// Perform a completion request
66    pub async fn logprob(&self, engine: &Engine, request: &Request) -> Result<Response, Error> {
67        let request_json = serde_json::to_string(&request)?;
68        let url = format!("{}/engines/{}/logprob", self.base_url, engine);
69        let response = self.client.post(&url).body(request_json).send().await?;
70        // println!("got response {:?}", response.text().await);
71        response.json().await.map_err(|e| e.into())
72    }
73}