use agentkit_adapter_completions::{
CompletionsAdapter, CompletionsError, CompletionsProvider, CompletionsSession, CompletionsTurn,
};
use agentkit_loop::{LoopError, ModelAdapter, SessionConfig};
use async_trait::async_trait;
use serde::Serialize;
use thiserror::Error;
/// Endpoint URL that `VllmConfig::new` installs as the initial `base_url`.
///
/// Note that this is a complete chat-completions URL (host, port and path),
/// not merely a server root.
const DEFAULT_ENDPOINT: &str = "http://localhost:8000/v1/chat/completions";
/// User-facing settings for talking to a vLLM server.
///
/// Build one with `VllmConfig::new` or `VllmConfig::from_env`, refine it with
/// the `with_*` methods, then convert it into a `VllmProvider` (via `From`) or
/// hand it to `VllmAdapter::new`.
#[derive(Clone, Debug)]
pub struct VllmConfig {
/// Model identifier; becomes the `model` field of `VllmRequestConfig`.
pub model: String,
/// URL the provider reports from `endpoint_url()`. Despite the name, the
/// default value (`DEFAULT_ENDPOINT`) is a full chat-completions URL rather
/// than a bare server root.
pub base_url: String,
/// Optional API key; when present it is attached to each request as a
/// bearer token by `preprocess_request`.
pub api_key: Option<String>,
/// Optional `temperature` request parameter; omitted from the serialized
/// request configuration when `None`.
pub temperature: Option<f32>,
/// Optional `max_completion_tokens` request parameter; omitted when `None`.
pub max_completion_tokens: Option<u32>,
/// Optional `top_p` request parameter; omitted when `None`.
pub top_p: Option<f32>,
/// Optional `parallel_tool_calls` request parameter; omitted when `None`.
pub parallel_tool_calls: Option<bool>,
/// Value the provider reports from `streaming()`. `VllmConfig::new` sets
/// this to `true`.
pub streaming: bool,
/// Value the provider reports from `requires_alternating_roles()`.
/// `VllmConfig::new` sets this to `false`.
pub strict_alternating_roles: bool,
}
impl VllmConfig {
    /// Creates a configuration for `model`.
    ///
    /// Every optional request parameter starts unset, streaming is enabled,
    /// strict role alternation is disabled, and `base_url` is set to
    /// `DEFAULT_ENDPOINT`.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            base_url: DEFAULT_ENDPOINT.into(),
            api_key: None,
            temperature: None,
            max_completion_tokens: None,
            top_p: None,
            parallel_tool_calls: None,
            streaming: true,
            strict_alternating_roles: false,
        }
    }

    /// Replaces the endpoint URL.
    ///
    /// The value is used verbatim as the provider's endpoint, so it should be
    /// a full URL of the same shape as `DEFAULT_ENDPOINT`.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets the API key that is sent as a bearer token.
    #[must_use]
    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Sets the `temperature` request parameter.
    #[must_use]
    pub fn with_temperature(mut self, v: f32) -> Self {
        self.temperature = Some(v);
        self
    }

    /// Sets the `max_completion_tokens` request parameter.
    #[must_use]
    pub fn with_max_completion_tokens(mut self, v: u32) -> Self {
        self.max_completion_tokens = Some(v);
        self
    }

    /// Sets the `top_p` request parameter.
    #[must_use]
    pub fn with_top_p(mut self, v: f32) -> Self {
        self.top_p = Some(v);
        self
    }

    /// Sets the `parallel_tool_calls` request parameter.
    #[must_use]
    pub fn with_parallel_tool_calls(mut self, flag: bool) -> Self {
        self.parallel_tool_calls = Some(flag);
        self
    }

    /// Enables or disables streaming.
    #[must_use]
    pub fn with_streaming(mut self, flag: bool) -> Self {
        self.streaming = flag;
        self
    }

    /// Sets whether the provider reports that it requires alternating roles.
    #[must_use]
    pub fn with_strict_alternating_roles(mut self, flag: bool) -> Self {
        self.strict_alternating_roles = flag;
        self
    }

    /// Builds a configuration from environment variables.
    ///
    /// * `VLLM_MODEL` (required) - model identifier.
    /// * `VLLM_BASE_URL` (optional) - overrides the default endpoint URL.
    /// * `VLLM_API_KEY` (optional) - bearer token.
    ///
    /// A variable that is unset, not valid Unicode, or blank (empty or only
    /// whitespace) is treated as absent. This keeps a placeholder such as
    /// `VLLM_API_KEY=` from producing an empty bearer token, an empty
    /// endpoint, or an empty model name.
    ///
    /// # Errors
    ///
    /// Returns `VllmError::MissingEnv("VLLM_MODEL")` when `VLLM_MODEL` is
    /// absent in the sense described above.
    pub fn from_env() -> Result<Self, VllmError> {
        let model =
            Self::env_non_blank("VLLM_MODEL").ok_or(VllmError::MissingEnv("VLLM_MODEL"))?;
        let mut config = Self::new(model);
        if let Some(url) = Self::env_non_blank("VLLM_BASE_URL") {
            config = config.with_base_url(url);
        }
        if let Some(key) = Self::env_non_blank("VLLM_API_KEY") {
            config = config.with_api_key(key);
        }
        Ok(config)
    }

    /// Reads `name` from the environment, yielding `None` when it is unset,
    /// not valid Unicode, or blank. A non-blank value is returned unmodified.
    fn env_non_blank(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .filter(|value| !value.trim().is_empty())
    }
}
/// Serializable request parameters, exposed to the completions adapter through
/// `CompletionsProvider::config`.
///
/// Every `Option` field is left out of the serialized output when it is `None`.
// NOTE(review): presumably the adapter merges this into the request body —
// confirm against `agentkit_adapter_completions`.
#[derive(Clone, Debug, Serialize)]
pub struct VllmRequestConfig {
/// Model identifier; always serialized.
pub model: String,
/// `temperature` parameter; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// `max_completion_tokens` parameter; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
/// `top_p` parameter; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// `parallel_tool_calls` parameter; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
}
/// `CompletionsProvider` implementation for vLLM.
///
/// Created from a `VllmConfig` via `From`; holds the connection-level settings
/// separately from the serializable request parameters.
#[derive(Clone, Debug)]
pub struct VllmProvider {
// Returned verbatim from `endpoint_url()`.
base_url: String,
// Attached as a bearer token in `preprocess_request` when present.
api_key: Option<String>,
// Returned from `streaming()`.
streaming: bool,
// Returned from `requires_alternating_roles()`.
strict_alternating_roles: bool,
// Returned by reference from `config()`.
request_config: VllmRequestConfig,
}
impl From<VllmConfig> for VllmProvider {
fn from(config: VllmConfig) -> Self {
Self {
base_url: config.base_url,
api_key: config.api_key,
streaming: config.streaming,
strict_alternating_roles: config.strict_alternating_roles,
request_config: VllmRequestConfig {
model: config.model,
temperature: config.temperature,
max_completion_tokens: config.max_completion_tokens,
top_p: config.top_p,
parallel_tool_calls: config.parallel_tool_calls,
},
}
}
}
impl CompletionsProvider for VllmProvider {
    type Config = VllmRequestConfig;

    /// Display name of this provider.
    fn provider_name(&self) -> &str {
        "vLLM"
    }

    /// URL requests are sent to; this is the configured `base_url`, unchanged.
    fn endpoint_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// Serializable request parameters for this provider.
    fn config(&self) -> &VllmRequestConfig {
        &self.request_config
    }

    /// Adds a `User-Agent` header carrying this crate's version and, when an
    /// API key is configured, bearer authentication.
    fn preprocess_request(
        &self,
        builder: agentkit_http::HttpRequestBuilder,
    ) -> agentkit_http::HttpRequestBuilder {
        // Assembled at compile time from the crate version Cargo exports.
        const USER_AGENT: &str = concat!("agentkit-provider-vllm/", env!("CARGO_PKG_VERSION"));

        let with_agent = builder.header("User-Agent", USER_AGENT);
        if let Some(key) = &self.api_key {
            with_agent.bearer_auth(key)
        } else {
            with_agent
        }
    }

    /// Whether streaming is enabled, as configured.
    fn streaming(&self) -> bool {
        self.streaming
    }

    /// Whether alternating roles are required, as configured.
    fn requires_alternating_roles(&self) -> bool {
        self.strict_alternating_roles
    }
}
/// Model adapter for vLLM: a newtype around `CompletionsAdapter<VllmProvider>`
/// that implements `ModelAdapter` by delegating to the wrapped adapter.
#[derive(Clone)]
pub struct VllmAdapter(CompletionsAdapter<VllmProvider>);
/// Session type produced by `VllmAdapter::start_session`; an alias for
/// `CompletionsSession<VllmProvider>`.
pub type VllmSession = CompletionsSession<VllmProvider>;
/// Alias for `CompletionsTurn`.
pub type VllmTurn = CompletionsTurn;
impl VllmAdapter {
pub fn new(config: VllmConfig) -> Result<Self, VllmError> {
let provider = VllmProvider::from(config);
Ok(Self(CompletionsAdapter::new(provider)?))
}
}
#[async_trait]
impl ModelAdapter for VllmAdapter {
    type Session = VllmSession;

    /// Starts a session by forwarding `config` to the wrapped
    /// `CompletionsAdapter`; its result is returned unchanged.
    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
        let inner = &self.0;
        inner.start_session(config).await
    }
}
/// Errors produced while configuring or constructing the vLLM adapter.
#[derive(Debug, Error)]
pub enum VllmError {
/// A required environment variable was not available; the payload is the
/// variable's name. Returned by `VllmConfig::from_env`.
#[error("missing environment variable {0}")]
MissingEnv(&'static str),
/// An error from the underlying completions adapter. Its message and source
/// are forwarded unchanged (`transparent`), and `#[from]` lets `?` convert a
/// `CompletionsError` into this variant.
#[error(transparent)]
Completions(#[from] CompletionsError),
}