dynamo_llm/
engines.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::env;
5use std::sync::Arc;
6use std::sync::LazyLock;
7use std::time::Duration;
8
9use async_stream::stream;
10use async_trait::async_trait;
11
12use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
13use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
14use dynamo_runtime::protocols::annotated::Annotated;
15
16use crate::protocols::openai::{
17    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
18    completions::{NvCreateCompletionRequest, NvCreateCompletionResponse, prompt_to_string},
19};
20use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
21use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
22
23//
24// The engines are each in their own crate under `lib/engines`
25//
26
/// Topology description for a multi-host engine deployment.
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
    /// How many nodes / hosts we are using
    pub num_nodes: u32,
    /// Unique consecutive integer to identify this node
    pub node_rank: u32,
    /// host:port of head / control node
    pub leader_addr: String,
}

impl Default for MultiNodeConfig {
    /// Single-node topology: one node, rank 0, and no leader address.
    fn default() -> Self {
        Self {
            num_nodes: 1,
            node_rank: 0,
            leader_addr: String::new(),
        }
    }
}
46
47//
48// Example echo engines
49//
50
/// How long to sleep between echoed tokens.
/// Default is 10ms which gives us 100 tok/s.
/// Can be configured via the DYN_TOKEN_ECHO_DELAY_MS environment variable.
pub static TOKEN_ECHO_DELAY: LazyLock<Duration> = LazyLock::new(|| {
    const DEFAULT_DELAY_MS: u64 = 10;

    // Fall back to the default both when the variable is absent and when
    // its value does not parse as an integer number of milliseconds.
    let delay_ms = match env::var("DYN_TOKEN_ECHO_DELAY_MS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_DELAY_MS),
        Err(_) => DEFAULT_DELAY_MS,
    };

    Duration::from_millis(delay_ms)
});
64
/// Engine that accepts un-preprocessed requests and echos the prompt back as the response
/// Useful for testing ingress such as service-http.
///
/// Stateless marker type; all behavior lives in its `AsyncEngine` impls.
struct EchoEngine {}
68
/// Validate Engine that verifies request data
pub struct ValidateEngine<E> {
    inner: E,
}

impl<E> ValidateEngine<E> {
    /// Wrap `engine` so each request is validated before being forwarded.
    pub fn new(engine: E) -> Self {
        Self { inner: engine }
    }
}
79
/// Engine that dispatches requests to either OpenAICompletions
/// or OpenAIChatCompletions engine
pub struct EngineDispatcher<E> {
    inner: E,
}

impl<E> EngineDispatcher<E> {
    /// Wrap the engine that will receive the dispatched requests.
    pub fn new(engine: E) -> Self {
        Self { inner: engine }
    }
}
91
/// Trait on request types that allows us to validate the data
pub trait ValidateRequest {
    /// Return `Ok(())` if the request data is acceptable, otherwise an
    /// error describing why it was rejected.
    fn validate(&self) -> Result<(), anyhow::Error>;
}
96
/// Trait that allows handling both completion and chat completions requests
#[async_trait]
pub trait StreamingEngine: Send + Sync {
    /// Handle a completions request, returning a stream of annotated
    /// completion response chunks.
    async fn handle_completion(
        &self,
        req: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error>;

    /// Handle a chat completions request, returning a stream of annotated
    /// chat completion response chunks.
    async fn handle_chat(
        &self,
        req: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error>;
}
110
/// Trait that allows handling embedding requests
#[async_trait]
pub trait EmbeddingEngine: Send + Sync {
    /// Handle an embedding request, returning a stream of annotated
    /// embedding responses.
    async fn handle_embedding(
        &self,
        req: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error>;
}
119
120pub fn make_echo_engine() -> Arc<dyn StreamingEngine> {
121    let engine = EchoEngine {};
122    let data = EngineDispatcher::new(engine);
123    Arc::new(data)
124}
125
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for EchoEngine
{
    /// Echo the text of the last chat message back as a stream of
    /// single-character chunks, sleeping `TOKEN_ECHO_DELAY` between chunks,
    /// then close the stream with a `Stop` finish reason.
    ///
    /// Errors if the message list is empty, the last message is not a user
    /// message, or its content is not plain text.
    async fn generate(
        &self,
        incoming_request: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = incoming_request.transfer(());
        let ctx = context.context();
        let mut deltas = request.response_generator(ctx.id().to_string());
        // Only the most recent message in the conversation is echoed.
        let Some(req) = request.inner.messages.into_iter().next_back() else {
            anyhow::bail!("Empty chat messages in request");
        };

        // Extract the prompt text: only a plain-text user message qualifies.
        let prompt = match req {
            dynamo_async_openai::types::ChatCompletionRequestMessage::User(user_msg) => {
                match user_msg.content {
                    dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
                        prompt,
                    ) => prompt,
                    _ => anyhow::bail!("Invalid request content field, expected Content::Text"),
                }
            }
            _ => anyhow::bail!("Invalid request type, expected User message"),
        };

        let output = stream! {
            let mut id = 1;
            for c in prompt.chars() {
                // we are returning characters not tokens, so there will be some postprocessing overhead
                tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
                let response = deltas.create_choice(0, Some(c.to_string()), None, None, None);
                yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
                id += 1;
            }

            // Final chunk carries no content, only the Stop finish reason.
            let response = deltas.create_choice(0, None, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
            yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
        };

        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
174
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        Error,
    > for EchoEngine
{
    /// Echo the completion prompt back as a stream of single-character
    /// chunks, sleeping `TOKEN_ECHO_DELAY` between chunks, then close the
    /// stream with a `Stop` finish reason.
    async fn generate(
        &self,
        incoming_request: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
        let (request, context) = incoming_request.transfer(());
        let ctx = context.context();
        let deltas = request.response_generator(ctx.id().to_string());
        // Flatten the (possibly structured) prompt into a single string.
        let chars_string = prompt_to_string(&request.inner.prompt);
        let output = stream! {
            let mut id = 1;
            for c in chars_string.chars() {
                tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
                let response = deltas.create_choice(0, Some(c.to_string()), None, None);
                yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
                id += 1;
            }
            // Final chunk carries no content, only the Stop finish reason.
            let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::CompletionFinishReason::Stop), None);
            yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };

        };

        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
207
208#[async_trait]
209impl
210    AsyncEngine<
211        SingleIn<NvCreateEmbeddingRequest>,
212        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
213        Error,
214    > for EchoEngine
215{
216    async fn generate(
217        &self,
218        _incoming_request: SingleIn<NvCreateEmbeddingRequest>,
219    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
220        unimplemented!()
221    }
222}
223
224#[async_trait]
225impl<E, Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for ValidateEngine<E>
226where
227    E: AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> + Send + Sync,
228    Req: ValidateRequest + Send + Sync + 'static,
229    Resp: Send + Sync + 'static,
230{
231    async fn generate(
232        &self,
233        incoming_request: SingleIn<Req>,
234    ) -> Result<ManyOut<Annotated<Resp>>, Error> {
235        let (request, context) = incoming_request.into_parts();
236
237        // Validate the request first
238        if let Err(validation_error) = request.validate() {
239            return Err(anyhow::anyhow!("Validation failed: {}", validation_error));
240        }
241
242        // Forward to inner engine if validation passes
243        let validated_request = SingleIn::rejoin(request, context);
244        self.inner.generate(validated_request).await
245    }
246}
247
#[async_trait]
impl<E> StreamingEngine for EngineDispatcher<E>
where
    E: AsyncEngine<
            SingleIn<NvCreateCompletionRequest>,
            ManyOut<Annotated<NvCreateCompletionResponse>>,
            Error,
        > + AsyncEngine<
            SingleIn<NvCreateChatCompletionRequest>,
            ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
            Error,
        > + AsyncEngine<
            SingleIn<NvCreateEmbeddingRequest>,
            ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            Error,
        > + Send
        + Sync,
{
    /// Forward a completions request to the wrapped engine.
    async fn handle_completion(
        &self,
        req: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
        self.inner.generate(req).await
    }

    /// Forward a chat completions request to the wrapped engine.
    async fn handle_chat(
        &self,
        req: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        self.inner.generate(req).await
    }
}
280
#[async_trait]
impl<E> EmbeddingEngine for EngineDispatcher<E>
where
    E: AsyncEngine<
            SingleIn<NvCreateEmbeddingRequest>,
            ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            Error,
        > + Send
        + Sync,
{
    /// Forward an embedding request to the wrapped engine.
    async fn handle_embedding(
        &self,
        req: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        self.inner.generate(req).await
    }
}
298
299pub struct EmbeddingEngineAdapter(Arc<dyn EmbeddingEngine>);
300
301impl EmbeddingEngineAdapter {
302    pub fn new(engine: Arc<dyn EmbeddingEngine>) -> Self {
303        EmbeddingEngineAdapter(engine)
304    }
305}
306
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for EmbeddingEngineAdapter
{
    /// Delegate the embedding request to the wrapped `EmbeddingEngine`.
    async fn generate(
        &self,
        req: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        self.0.handle_embedding(req).await
    }
}
322
323pub struct StreamingEngineAdapter(Arc<dyn StreamingEngine>);
324
325impl StreamingEngineAdapter {
326    pub fn new(engine: Arc<dyn StreamingEngine>) -> Self {
327        StreamingEngineAdapter(engine)
328    }
329}
330
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        Error,
    > for StreamingEngineAdapter
{
    /// Delegate the completions request to the wrapped `StreamingEngine`.
    async fn generate(
        &self,
        req: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
        self.0.handle_completion(req).await
    }
}
346
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for StreamingEngineAdapter
{
    /// Delegate the chat completions request to the wrapped `StreamingEngine`.
    async fn generate(
        &self,
        req: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        self.0.handle_chat(req).await
    }
}