agentkit_provider_vllm/lib.rs
//! vLLM model adapter for the agentkit agent loop.
//!
//! This crate provides [`VllmAdapter`] and [`VllmConfig`] for connecting
//! the agent loop to a [vLLM](https://docs.vllm.ai) server via its
//! OpenAI-compatible chat completions endpoint. It is built on the generic
//! [`agentkit_adapter_completions`] crate.
//!
//! An API key is optional — vLLM servers can run with or without authentication.
//! When the server was started with `--api-key`, supply the key with
//! [`VllmConfig::with_api_key`] or through the `VLLM_API_KEY` environment
//! variable (see [`VllmConfig::from_env`]).
//!
//! # Quick start
//!
//! ```rust,ignore
//! use agentkit_loop::{Agent, SessionConfig};
//! use agentkit_provider_vllm::{VllmAdapter, VllmConfig};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let config = VllmConfig::new("meta-llama/Llama-3.1-8B-Instruct");
//!     let adapter = VllmAdapter::new(config)?;
//!
//!     let agent = Agent::builder()
//!         .model(adapter)
//!         .build()?;
//!
//!     let mut driver = agent
//!         .start(SessionConfig::new("demo"))
//!         .await?;
//!     Ok(())
//! }
//! ```
31
32use agentkit_adapter_completions::{
33 CompletionsAdapter, CompletionsError, CompletionsProvider, CompletionsSession, CompletionsTurn,
34};
35use agentkit_loop::{LoopError, ModelAdapter, SessionConfig};
36use async_trait::async_trait;
37use serde::Serialize;
38use thiserror::Error;
39
/// Chat completions endpoint of a vLLM server on `localhost:8000`, used when
/// no other URL is configured.
const DEFAULT_ENDPOINT: &str = "http://localhost:8000/v1/chat/completions";

/// Configuration for connecting to a vLLM server.
///
/// An API key is only required if the vLLM server was started with
/// `--api-key`. Build one with [`VllmConfig::new`] for explicit values,
/// or [`VllmConfig::from_env`] to read from environment variables.
///
/// The `Debug` output never contains the API key itself, only whether one
/// is set, so a configuration can be logged without leaking credentials.
///
/// # Example
///
/// ```rust,no_run
/// use agentkit_provider_vllm::VllmConfig;
///
/// let config = VllmConfig::new("meta-llama/Llama-3.1-8B-Instruct")
///     .with_base_url("http://gpu-server:8000/v1/chat/completions")
///     .with_temperature(0.0);
/// ```
#[derive(Clone)]
pub struct VllmConfig {
    /// HuggingFace model identifier served by the vLLM instance,
    /// e.g. `"meta-llama/Llama-3.1-8B-Instruct"`.
    pub model: String,
    /// Chat completions endpoint URL. Defaults to `http://localhost:8000/v1/chat/completions`.
    pub base_url: String,
    /// Optional API key, required only if the vLLM server enforces authentication.
    /// Redacted from `Debug` output.
    pub api_key: Option<String>,
    /// Sampling temperature (0.0 = deterministic, higher = more creative).
    pub temperature: Option<f32>,
    /// Maximum number of completion tokens the model may generate.
    pub max_completion_tokens: Option<u32>,
    /// Nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// Whether the model is allowed to emit multiple tool calls in a
    /// single turn. Omitted from the request when `None`.
    pub parallel_tool_calls: Option<bool>,
    /// Request SSE streaming responses. Defaults to `true`.
    pub streaming: bool,
    /// Whether the loaded chat template enforces strict
    /// `user`/`assistant` role alternation. Set to `true` for
    /// Mistral-/Mixtral-/Llama-template models served via vLLM, which
    /// otherwise return `Conversation roles must alternate
    /// user/assistant/user/assistant/...`. See
    /// <https://github.com/vllm-project/vllm/issues/6862>.
    pub strict_alternating_roles: bool,
}

impl std::fmt::Debug for VllmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VllmConfig")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            // Report only whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("temperature", &self.temperature)
            .field("max_completion_tokens", &self.max_completion_tokens)
            .field("top_p", &self.top_p)
            .field("parallel_tool_calls", &self.parallel_tool_calls)
            .field("streaming", &self.streaming)
            .field("strict_alternating_roles", &self.strict_alternating_roles)
            .finish()
    }
}
85
86impl VllmConfig {
87 /// Creates a new configuration with the given model identifier.
88 pub fn new(model: impl Into<String>) -> Self {
89 Self {
90 model: model.into(),
91 base_url: DEFAULT_ENDPOINT.into(),
92 api_key: None,
93 temperature: None,
94 max_completion_tokens: None,
95 top_p: None,
96 parallel_tool_calls: None,
97 streaming: true,
98 strict_alternating_roles: false,
99 }
100 }
101
102 /// Overrides the default chat completions endpoint URL.
103 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
104 self.base_url = url.into();
105 self
106 }
107
108 /// Sets the API key for authenticated vLLM servers.
109 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
110 self.api_key = Some(key.into());
111 self
112 }
113
114 /// Sets the sampling temperature (0.0 for deterministic output).
115 pub fn with_temperature(mut self, v: f32) -> Self {
116 self.temperature = Some(v);
117 self
118 }
119
120 /// Sets the maximum number of tokens the model may generate per turn.
121 pub fn with_max_completion_tokens(mut self, v: u32) -> Self {
122 self.max_completion_tokens = Some(v);
123 self
124 }
125
126 /// Sets the nucleus sampling parameter.
127 pub fn with_top_p(mut self, v: f32) -> Self {
128 self.top_p = Some(v);
129 self
130 }
131
132 /// Sets whether the model may emit multiple tool calls in a single turn.
133 pub fn with_parallel_tool_calls(mut self, flag: bool) -> Self {
134 self.parallel_tool_calls = Some(flag);
135 self
136 }
137
138 /// Toggles SSE streaming of model responses. Default: true.
139 pub fn with_streaming(mut self, flag: bool) -> Self {
140 self.streaming = flag;
141 self
142 }
143
144 /// Enable strict `user`/`assistant` role alternation for chat
145 /// templates that require it (notably Mistral, Mixtral, Llama). The
146 /// adapter merges adjacent user-role messages before sending. See
147 /// <https://github.com/vllm-project/vllm/issues/6862>.
148 pub fn with_strict_alternating_roles(mut self, flag: bool) -> Self {
149 self.strict_alternating_roles = flag;
150 self
151 }
152
153 /// Builds a configuration from environment variables.
154 ///
155 /// | Variable | Required | Default |
156 /// |---|---|---|
157 /// | `VLLM_MODEL` | yes | -- |
158 /// | `VLLM_BASE_URL` | no | `http://localhost:8000/v1/chat/completions` |
159 /// | `VLLM_API_KEY` | no | -- |
160 pub fn from_env() -> Result<Self, VllmError> {
161 let model = std::env::var("VLLM_MODEL").map_err(|_| VllmError::MissingEnv("VLLM_MODEL"))?;
162
163 let mut config = Self::new(model);
164
165 if let Ok(url) = std::env::var("VLLM_BASE_URL") {
166 config = config.with_base_url(url);
167 }
168 if let Ok(key) = std::env::var("VLLM_API_KEY") {
169 config = config.with_api_key(key);
170 }
171
172 Ok(config)
173 }
174}
175
176/// Request parameters serialized into the vLLM request body.
177#[derive(Clone, Debug, Serialize)]
178pub struct VllmRequestConfig {
179 pub model: String,
180 #[serde(skip_serializing_if = "Option::is_none")]
181 pub temperature: Option<f32>,
182 #[serde(skip_serializing_if = "Option::is_none")]
183 pub max_completion_tokens: Option<u32>,
184 #[serde(skip_serializing_if = "Option::is_none")]
185 pub top_p: Option<f32>,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 pub parallel_tool_calls: Option<bool>,
188}
189
190/// The vLLM provider, implementing [`CompletionsProvider`].
191#[derive(Clone, Debug)]
192pub struct VllmProvider {
193 base_url: String,
194 api_key: Option<String>,
195 streaming: bool,
196 strict_alternating_roles: bool,
197 request_config: VllmRequestConfig,
198}
199
200impl From<VllmConfig> for VllmProvider {
201 fn from(config: VllmConfig) -> Self {
202 Self {
203 base_url: config.base_url,
204 api_key: config.api_key,
205 streaming: config.streaming,
206 strict_alternating_roles: config.strict_alternating_roles,
207 request_config: VllmRequestConfig {
208 model: config.model,
209 temperature: config.temperature,
210 max_completion_tokens: config.max_completion_tokens,
211 top_p: config.top_p,
212 parallel_tool_calls: config.parallel_tool_calls,
213 },
214 }
215 }
216}
217
218impl CompletionsProvider for VllmProvider {
219 type Config = VllmRequestConfig;
220
221 fn provider_name(&self) -> &str {
222 "vLLM"
223 }
224 fn endpoint_url(&self) -> &str {
225 &self.base_url
226 }
227 fn config(&self) -> &VllmRequestConfig {
228 &self.request_config
229 }
230
231 fn preprocess_request(
232 &self,
233 builder: agentkit_http::HttpRequestBuilder,
234 ) -> agentkit_http::HttpRequestBuilder {
235 let builder = builder.header(
236 "User-Agent",
237 concat!("agentkit-provider-vllm/", env!("CARGO_PKG_VERSION")),
238 );
239 match &self.api_key {
240 Some(key) => builder.bearer_auth(key),
241 None => builder,
242 }
243 }
244
245 fn streaming(&self) -> bool {
246 self.streaming
247 }
248
249 fn requires_alternating_roles(&self) -> bool {
250 self.strict_alternating_roles
251 }
252}
253
254/// Model adapter that connects the agentkit agent loop to a vLLM server.
255///
256/// # Example
257///
258/// ```rust,no_run
259/// use agentkit_loop::Agent;
260/// use agentkit_provider_vllm::{VllmAdapter, VllmConfig};
261///
262/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
263/// let adapter = VllmAdapter::new(
264/// VllmConfig::new("meta-llama/Llama-3.1-8B-Instruct"),
265/// )?;
266///
267/// let agent = Agent::builder()
268/// .model(adapter)
269/// .build()?;
270/// # Ok(())
271/// # }
272/// ```
273#[derive(Clone)]
274pub struct VllmAdapter(CompletionsAdapter<VllmProvider>);
275
276/// An active session with a vLLM server.
277pub type VllmSession = CompletionsSession<VllmProvider>;
278
279/// A completed turn from a vLLM server.
280pub type VllmTurn = CompletionsTurn;
281
282impl VllmAdapter {
283 /// Creates a new adapter from the given configuration.
284 pub fn new(config: VllmConfig) -> Result<Self, VllmError> {
285 let provider = VllmProvider::from(config);
286 Ok(Self(CompletionsAdapter::new(provider)?))
287 }
288}
289
290#[async_trait]
291impl ModelAdapter for VllmAdapter {
292 type Session = VllmSession;
293
294 async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
295 self.0.start_session(config).await
296 }
297}
298
299/// Errors produced by the vLLM adapter.
300#[derive(Debug, Error)]
301pub enum VllmError {
302 /// A required environment variable is not set.
303 #[error("missing environment variable {0}")]
304 MissingEnv(&'static str),
305
306 /// An error from the generic completions adapter.
307 #[error(transparent)]
308 Completions(#[from] CompletionsError),
309}