openai-utils 0.5.1

Utilities for using the OpenAI API.
Documentation
#![allow(dead_code)]
#![feature(type_name_of_val)]
#![feature(associated_type_defaults)]
#![feature(async_fn_in_trait)]

mod chat_completion;
mod chat_completion_delta;
mod chat_completion_request;
mod error;

use lazy_static::lazy_static;
#[allow(unused_imports)]
use log::{debug, error, info, trace, warn};
use std::{
    any::type_name_of_val,
    sync::{Arc, RwLock},
};

use schemars::{schema_for, JsonSchema};
use serde_derive::{Deserialize, Serialize};
use serde_json::Value;
pub use {
    chat_completion::ChatCompletion as Chat,
    chat_completion_delta::ChatCompletionDelta as ChatDelta, chat_completion_delta::DeltaReceiver,
    chat_completion_request::AiAgent,
    chat_completion_request::ChatCompletionRequest as ChatRequest,
};

lazy_static! {
    /// Process-wide slot for the OpenAI API key; starts out empty (`None`) and is
    /// filled in by [`api_key`].
    static ref OPENAI_API_KEY: Arc<RwLock<Option<String>>> = Arc::new(RwLock::default());
}

/// A single chat message, (de)serialized with serde.
///
/// Optional fields that are `None` are omitted from the serialized output
/// (`skip_serializing_if = "Option::is_none"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author role as free-form text; not validated here (presumably one of the
    /// roles accepted by the OpenAI chat API — TODO confirm).
    pub role: String,

    /// Message text, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Optional author name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Function call carried by this message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,
}

impl Message {
    pub fn new(role: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: None,
            name: None,
            function_call: None,
        }
    }

    pub fn with_content(mut self, content: impl Into<String>) -> Self {
        self.content = Some(content.into());
        self
    }

    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }
}

/// A function invocation attached to a [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a raw, unparsed string (presumably JSON text as produced by
    /// the model — TODO confirm).
    pub arguments: String,
}

/// Definition of a callable function: its name, optional description, and
/// parameter schema. Usually built with `Function::from`, which fills
/// `parameters` with the JSON Schema of the function's argument type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
    /// Function name.
    pub name: String,

    /// Human-readable description; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as an arbitrary JSON value.
    pub parameters: Value,
}

impl Function {
    pub fn from<FunctionArgs, Func, T>(function: &Func, function_name: &str) -> Self
    where
        FunctionArgs: JsonSchema,
        Func: FnMut(FunctionArgs) -> T,
    {
        let schema = schema_for!(FunctionArgs);
        let parameters = serde_json::to_value(schema)
            .unwrap_or_else(|_| panic!("Failed to serialize schema for function {}", function_name));
        Self {
            name: function_name.to_string(),
            description: match parameters.get("description") {
                Some(Value::String(s)) => Some(s.clone()),
                _ => None,
            },
            parameters,
        }
    }
}

/// One choice in a chat completion response carrying a full [`Message`]
/// (presumably the non-streaming counterpart of `ChoiceDelta` — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response's choice list.
    pub index: i64,
    /// The generated message.
    pub message: Message,
    // NOTE(review): non-optional, so a `null` or missing `finish_reason` will fail
    // to deserialize — confirm the API never sends one for this response type.
    /// Reason generation stopped, as free-form text.
    pub finish_reason: String,
}

/// One choice in an incremental (streamed) response, carrying a [`Delta`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChoiceDelta {
    /// Position of this choice in the response's choice list.
    pub index: i64,
    /// The incremental piece of the message for this choice.
    pub delta: Delta,

    /// Reason generation stopped; `None` when absent from the input.
    #[serde(default)]
    pub finish_reason: Option<String>,
}

/// Incremental fragment of a message. Every field is optional.
///
/// Unlike [`Message`], there are no `skip_serializing_if` attributes here, so
/// `None` fields serialize as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    /// Role, if this fragment carries it.
    pub role: Option<String>,

    /// Content text fragment, if any.
    pub content: Option<String>,

    /// Partial function call, if any.
    pub function_call: Option<FunctionCallDelta>,
}

/// Partial function call inside a [`Delta`]. Either part may be absent.
/// Presumably fragments are concatenated by the consumer into a full
/// [`FunctionCall`] — TODO confirm against the delta-assembly code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    /// Function name fragment, if present.
    pub name: Option<String>,
    /// Arguments text fragment, if present.
    pub arguments: Option<String>,
}

/// Token usage counters. `Usage::default()` has every counter set to zero.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u64,
    /// Tokens attributed to the completion.
    pub completion_tokens: u64,
    /// Total tokens as reported (not computed here from the other two fields).
    pub total_tokens: u64,
}

/// Stores `api_key` in the crate-wide `OPENAI_API_KEY` slot, replacing any
/// previously stored key.
///
/// If another thread panicked while holding the lock, the poison is ignored
/// rather than turned into a panic. This call overwrites the whole value, so any
/// partially-updated state left behind is discarded anyway. The lock's poison
/// flag itself is not cleared by this call.
pub fn api_key(api_key: String) {
    let mut key = OPENAI_API_KEY
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *key = Some(api_key);
}

/// Argument type for functions that take no parameters.
///
/// Use it as the argument type of a callable passed to `Function::from`. The
/// schema's description ("this function takes no arguments") is then picked up
/// as the function description.
#[derive(Default, Debug, Clone, JsonSchema)]
#[schemars(description = "this function takes no arguments")]
pub struct NoArgs {
    // Presumably a named-field struct (rather than a unit struct) so the derived
    // schema is object-shaped. The field is skipped so it does not appear in the
    // schema as a required `_unused` property of type null.
    #[schemars(skip)]
    _unused: (),
}

pub fn calculate_message_tokens(message: &Message) -> usize {
    let bpe = tiktoken_rs::cl100k_base().unwrap();

    bpe.encode_with_special_tokens(&message.content.as_ref().unwrap_or(&"".to_string())).len()
}

/// Returns the number of tokens `s` encodes to under the `cl100k_base` encoding,
/// using `encode_with_special_tokens`.
///
/// The encoder is built afresh on every call.
///
/// # Panics
///
/// Panics if `tiktoken_rs::cl100k_base()` returns an error.
pub fn calculate_tokens(s: &str) -> usize {
    tiktoken_rs::cl100k_base()
        .unwrap()
        .encode_with_special_tokens(s)
        .len()
}