Skip to main content

agentkit_provider_vllm/
lib.rs

1//! vLLM model adapter for the agentkit agent loop.
2//!
3//! This crate provides [`VllmAdapter`] and [`VllmConfig`] for connecting
4//! the agent loop to a [vLLM](https://docs.vllm.ai) server via its
5//! OpenAI-compatible chat completions endpoint. It is built on the generic
6//! [`agentkit_adapter_completions`] crate.
7//!
8//! An API key is optional — vLLM servers can run with or without authentication.
9//!
10//! # Quick start
11//!
12//! ```rust,ignore
13//! use agentkit_loop::{Agent, SessionConfig};
14//! use agentkit_provider_vllm::{VllmAdapter, VllmConfig};
15//!
16//! #[tokio::main]
17//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
18//!     let config = VllmConfig::new("meta-llama/Llama-3.1-8B-Instruct");
19//!     let adapter = VllmAdapter::new(config)?;
20//!
21//!     let agent = Agent::builder()
22//!         .model(adapter)
23//!         .build()?;
24//!
25//!     let mut driver = agent
26//!         .start(SessionConfig::new("demo"))
27//!         .await?;
28//!     Ok(())
29//! }
30//! ```
31
32use agentkit_adapter_completions::{
33    CompletionsAdapter, CompletionsError, CompletionsProvider, CompletionsSession, CompletionsTurn,
34};
35use agentkit_loop::{LoopError, ModelAdapter, SessionConfig};
36use async_trait::async_trait;
37use serde::Serialize;
38use thiserror::Error;
39
/// Chat completions endpoint used when no base URL is configured.
///
/// [`VllmConfig::new`] seeds `base_url` with this value.
const DEFAULT_ENDPOINT: &str = "http://localhost:8000/v1/chat/completions";
41
/// Configuration for connecting to a vLLM server.
///
/// An API key is only required if the vLLM server was started with
/// `--api-key`. Build one with [`VllmConfig::new`] for explicit values,
/// or [`VllmConfig::from_env`] to read from environment variables.
///
/// The [`Debug`](std::fmt::Debug) representation never prints the API key
/// itself, only whether one is set, so a configuration can be logged
/// without leaking the credential.
///
/// # Example
///
/// ```rust,no_run
/// use agentkit_provider_vllm::VllmConfig;
///
/// let config = VllmConfig::new("meta-llama/Llama-3.1-8B-Instruct")
///     .with_base_url("http://gpu-server:8000/v1/chat/completions")
///     .with_temperature(0.0);
/// ```
#[derive(Clone)]
pub struct VllmConfig {
    /// HuggingFace model identifier served by the vLLM instance,
    /// e.g. `"meta-llama/Llama-3.1-8B-Instruct"`.
    pub model: String,
    /// Chat completions endpoint URL. Defaults to `http://localhost:8000/v1/chat/completions`.
    pub base_url: String,
    /// Optional API key, required only if the vLLM server enforces authentication.
    pub api_key: Option<String>,
    /// Sampling temperature (0.0 = deterministic, higher = more creative).
    pub temperature: Option<f32>,
    /// Maximum number of completion tokens the model may generate.
    pub max_completion_tokens: Option<u32>,
    /// Nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// Whether the model is allowed to emit multiple tool calls in a
    /// single turn. Omitted from the request when `None`.
    pub parallel_tool_calls: Option<bool>,
    /// Request SSE streaming responses. Defaults to `true`.
    pub streaming: bool,
    /// Whether the loaded chat template enforces strict
    /// `user`/`assistant` role alternation. Set to `true` for
    /// Mistral-/Mixtral-/Llama-template models served via vLLM, which
    /// otherwise return `Conversation roles must alternate
    /// user/assistant/user/assistant/...`. See
    /// <https://github.com/vllm-project/vllm/issues/6862>.
    pub strict_alternating_roles: bool,
}

// Hand-written instead of derived so that the secret in `api_key` is never
// written to logs or panic messages through `{:?}` formatting.
impl std::fmt::Debug for VllmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible while hiding the key's value.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("VllmConfig")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_key", &api_key)
            .field("temperature", &self.temperature)
            .field("max_completion_tokens", &self.max_completion_tokens)
            .field("top_p", &self.top_p)
            .field("parallel_tool_calls", &self.parallel_tool_calls)
            .field("streaming", &self.streaming)
            .field("strict_alternating_roles", &self.strict_alternating_roles)
            .finish()
    }
}
85
86impl VllmConfig {
87    /// Creates a new configuration with the given model identifier.
88    pub fn new(model: impl Into<String>) -> Self {
89        Self {
90            model: model.into(),
91            base_url: DEFAULT_ENDPOINT.into(),
92            api_key: None,
93            temperature: None,
94            max_completion_tokens: None,
95            top_p: None,
96            parallel_tool_calls: None,
97            streaming: true,
98            strict_alternating_roles: false,
99        }
100    }
101
102    /// Overrides the default chat completions endpoint URL.
103    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
104        self.base_url = url.into();
105        self
106    }
107
108    /// Sets the API key for authenticated vLLM servers.
109    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
110        self.api_key = Some(key.into());
111        self
112    }
113
114    /// Sets the sampling temperature (0.0 for deterministic output).
115    pub fn with_temperature(mut self, v: f32) -> Self {
116        self.temperature = Some(v);
117        self
118    }
119
120    /// Sets the maximum number of tokens the model may generate per turn.
121    pub fn with_max_completion_tokens(mut self, v: u32) -> Self {
122        self.max_completion_tokens = Some(v);
123        self
124    }
125
126    /// Sets the nucleus sampling parameter.
127    pub fn with_top_p(mut self, v: f32) -> Self {
128        self.top_p = Some(v);
129        self
130    }
131
132    /// Sets whether the model may emit multiple tool calls in a single turn.
133    pub fn with_parallel_tool_calls(mut self, flag: bool) -> Self {
134        self.parallel_tool_calls = Some(flag);
135        self
136    }
137
138    /// Toggles SSE streaming of model responses. Default: true.
139    pub fn with_streaming(mut self, flag: bool) -> Self {
140        self.streaming = flag;
141        self
142    }
143
144    /// Enable strict `user`/`assistant` role alternation for chat
145    /// templates that require it (notably Mistral, Mixtral, Llama). The
146    /// adapter merges adjacent user-role messages before sending. See
147    /// <https://github.com/vllm-project/vllm/issues/6862>.
148    pub fn with_strict_alternating_roles(mut self, flag: bool) -> Self {
149        self.strict_alternating_roles = flag;
150        self
151    }
152
153    /// Builds a configuration from environment variables.
154    ///
155    /// | Variable | Required | Default |
156    /// |---|---|---|
157    /// | `VLLM_MODEL` | yes | -- |
158    /// | `VLLM_BASE_URL` | no | `http://localhost:8000/v1/chat/completions` |
159    /// | `VLLM_API_KEY` | no | -- |
160    pub fn from_env() -> Result<Self, VllmError> {
161        let model = std::env::var("VLLM_MODEL").map_err(|_| VllmError::MissingEnv("VLLM_MODEL"))?;
162
163        let mut config = Self::new(model);
164
165        if let Ok(url) = std::env::var("VLLM_BASE_URL") {
166            config = config.with_base_url(url);
167        }
168        if let Ok(key) = std::env::var("VLLM_API_KEY") {
169            config = config.with_api_key(key);
170        }
171
172        Ok(config)
173    }
174}
175
/// Request parameters serialized into the vLLM request body.
///
/// Populated from a [`VllmConfig`] by the `From<VllmConfig>` implementation
/// for [`VllmProvider`]. Every `Option` field is skipped during
/// serialization when it is `None`, so unset parameters do not appear in
/// the serialized output at all.
#[derive(Clone, Debug, Serialize)]
pub struct VllmRequestConfig {
    /// Model identifier, taken from [`VllmConfig::model`].
    pub model: String,
    /// Sampling temperature; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Maximum number of completion tokens; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// Nucleus sampling parameter; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Whether multiple tool calls per turn are allowed; not serialized
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
}
189
190/// The vLLM provider, implementing [`CompletionsProvider`].
191#[derive(Clone, Debug)]
192pub struct VllmProvider {
193    base_url: String,
194    api_key: Option<String>,
195    streaming: bool,
196    strict_alternating_roles: bool,
197    request_config: VllmRequestConfig,
198}
199
200impl From<VllmConfig> for VllmProvider {
201    fn from(config: VllmConfig) -> Self {
202        Self {
203            base_url: config.base_url,
204            api_key: config.api_key,
205            streaming: config.streaming,
206            strict_alternating_roles: config.strict_alternating_roles,
207            request_config: VllmRequestConfig {
208                model: config.model,
209                temperature: config.temperature,
210                max_completion_tokens: config.max_completion_tokens,
211                top_p: config.top_p,
212                parallel_tool_calls: config.parallel_tool_calls,
213            },
214        }
215    }
216}
217
218impl CompletionsProvider for VllmProvider {
219    type Config = VllmRequestConfig;
220
221    fn provider_name(&self) -> &str {
222        "vLLM"
223    }
224    fn endpoint_url(&self) -> &str {
225        &self.base_url
226    }
227    fn config(&self) -> &VllmRequestConfig {
228        &self.request_config
229    }
230
231    fn preprocess_request(
232        &self,
233        builder: agentkit_http::HttpRequestBuilder,
234    ) -> agentkit_http::HttpRequestBuilder {
235        let builder = builder.header(
236            "User-Agent",
237            concat!("agentkit-provider-vllm/", env!("CARGO_PKG_VERSION")),
238        );
239        match &self.api_key {
240            Some(key) => builder.bearer_auth(key),
241            None => builder,
242        }
243    }
244
245    fn streaming(&self) -> bool {
246        self.streaming
247    }
248
249    fn requires_alternating_roles(&self) -> bool {
250        self.strict_alternating_roles
251    }
252}
253
/// Model adapter that connects the agentkit agent loop to a vLLM server.
///
/// This is a newtype around [`CompletionsAdapter`] parameterized with
/// [`VllmProvider`]; its [`ModelAdapter`] implementation forwards
/// `start_session` to the wrapped adapter.
///
/// # Example
///
/// ```rust,no_run
/// use agentkit_loop::Agent;
/// use agentkit_provider_vllm::{VllmAdapter, VllmConfig};
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let adapter = VllmAdapter::new(
///     VllmConfig::new("meta-llama/Llama-3.1-8B-Instruct"),
/// )?;
///
/// let agent = Agent::builder()
///     .model(adapter)
///     .build()?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct VllmAdapter(CompletionsAdapter<VllmProvider>);
275
/// An active session with a vLLM server.
///
/// Alias for [`CompletionsSession`] specialized to [`VllmProvider`]; this is
/// the `Session` type of [`VllmAdapter`]'s [`ModelAdapter`] implementation.
pub type VllmSession = CompletionsSession<VllmProvider>;

/// A completed turn from a vLLM server.
///
/// Plain alias for [`CompletionsTurn`] from the generic completions adapter.
pub type VllmTurn = CompletionsTurn;
281
282impl VllmAdapter {
283    /// Creates a new adapter from the given configuration.
284    pub fn new(config: VllmConfig) -> Result<Self, VllmError> {
285        let provider = VllmProvider::from(config);
286        Ok(Self(CompletionsAdapter::new(provider)?))
287    }
288}
289
290#[async_trait]
291impl ModelAdapter for VllmAdapter {
292    type Session = VllmSession;
293
294    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
295        self.0.start_session(config).await
296    }
297}
298
/// Errors produced by the vLLM adapter.
#[derive(Debug, Error)]
pub enum VllmError {
    /// A required environment variable is not set.
    ///
    /// Returned by [`VllmConfig::from_env`]; the payload is the name of the
    /// variable that could not be read.
    #[error("missing environment variable {0}")]
    MissingEnv(&'static str),

    /// An error from the generic completions adapter.
    ///
    /// Displayed transparently, i.e. with the wrapped error's own message.
    /// Produced via `From<CompletionsError>`, for example by
    /// [`VllmAdapter::new`].
    #[error(transparent)]
    Completions(#[from] CompletionsError),
}