elikoga_textsynth/completions/
logprob.rs1use serde::{Deserialize, Serialize};
4use serde_with::skip_serializing_none;
5use thiserror::Error;
6
7use crate::TextSynthClient;
8
9use super::Engine;
10
11#[skip_serializing_none]
13#[derive(Serialize, Builder)]
14#[builder(setter(into))]
15#[builder(build_fn(validate = "Self::validate"))]
16pub struct Request {
17 context: String,
19 continuation: String,
21}
22
23impl RequestBuilder {
24 fn validate(&self) -> Result<(), String> {
25 match &self.continuation {
27 Some(continuation) if continuation.is_empty() => {
28 return Err("Continuation must not be empty".to_string());
29 }
30 _ => {}
31 }
32 Ok(())
33 }
34}
35
36#[derive(Deserialize, Debug)]
38pub struct Response {
39 pub logprob: f64,
43 pub num_tokens: u32,
45 pub is_greedy: bool,
48 pub input_tokens: u32,
51}
52
53#[derive(Error, Debug)]
54pub enum Error {
56 #[error("Serde error: {0}")]
58 SerdeError(#[from] serde_json::Error),
59 #[error("Reqwest error: {0}")]
61 RequestError(#[from] reqwest::Error),
62}
63
64impl TextSynthClient {
65 pub async fn logprob(&self, engine: &Engine, request: &Request) -> Result<Response, Error> {
67 let request_json = serde_json::to_string(&request)?;
68 let url = format!("{}/engines/{}/logprob", self.base_url, engine);
69 let response = self.client.post(&url).body(request_json).send().await?;
70 response.json().await.map_err(|e| e.into())
72 }
73}