1use std::env;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20
21use async_stream::stream;
22use async_trait::async_trait;
23
24use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
25use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
26use dynamo_runtime::protocols::annotated::Annotated;
27
28use crate::backend::ExecutionContext;
29use crate::preprocessor::PreprocessedRequest;
30use crate::protocols::common::llm_backend::LLMEngineOutput;
31use crate::protocols::openai::{
32 chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
33 completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest},
34};
35use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
36use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
37
/// Configuration for running an engine across multiple nodes.
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
    /// Total number of nodes participating (1 for single-node, per `Default`).
    pub num_nodes: u32,
    /// This node's rank within the group (0 by default; presumably the leader — confirm with launcher).
    pub node_rank: u32,
    /// Address of the leader node; empty string when running single-node.
    pub leader_addr: String,
}
51
52impl Default for MultiNodeConfig {
53 fn default() -> Self {
54 MultiNodeConfig {
55 num_nodes: 1,
56 node_rank: 0,
57 leader_addr: "".to_string(),
58 }
59 }
60}
61
/// Delay inserted between echoed tokens/characters, overridable at runtime
/// via the `DYN_TOKEN_ECHO_DELAY_MS` environment variable.
pub static TOKEN_ECHO_DELAY: LazyLock<Duration> = LazyLock::new(|| {
    // Used when the variable is unset or fails to parse as u64.
    const DEFAULT_DELAY_MS: u64 = 10;

    let delay_ms = match env::var("DYN_TOKEN_ECHO_DELAY_MS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_DELAY_MS),
        Err(_) => DEFAULT_DELAY_MS,
    };

    Duration::from_millis(delay_ms)
});
79
/// Echo engine over preprocessed requests: replays the request's token ids.
struct EchoEngineCore {}
/// Build an [`ExecutionContext`] backed by [`EchoEngineCore`].
pub fn make_engine_core() -> ExecutionContext {
    Arc::new(EchoEngineCore {})
}
87
88#[async_trait]
89impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
90 for EchoEngineCore
91{
92 async fn generate(
93 &self,
94 incoming_request: SingleIn<PreprocessedRequest>,
95 ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
96 let (request, context) = incoming_request.into_parts();
97 let ctx = context.context();
98
99 let output = stream! {
100 for tok in request.token_ids {
101 tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
102 yield delta_core(tok);
103 }
104 yield Annotated::from_data(LLMEngineOutput::stop());
105 };
106 Ok(ResponseStream::new(Box::pin(output), ctx))
107 }
108}
109
110fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
111 let delta = LLMEngineOutput {
112 token_ids: vec![tok],
113 tokens: None,
114 text: None,
115 cum_log_probs: None,
116 log_probs: None,
117 finish_reason: None,
118 };
119 Annotated::from_data(delta)
120}
121
/// Echo engine over OpenAI-style requests: replays the prompt text character by character.
struct EchoEngineFull {}
125
/// Wraps a concrete engine so it can be exposed behind the object-safe
/// `StreamingEngine` / `EmbeddingEngine` traits.
pub struct EngineDispatcher<E> {
    // The wrapped engine; must implement the relevant `AsyncEngine` impls.
    inner: E,
}
131
132impl<E> EngineDispatcher<E> {
133 pub fn new(inner: E) -> Self {
134 EngineDispatcher { inner }
135 }
136}
137
#[async_trait]
/// Object-safe engine interface for OpenAI-style streaming requests
/// (completions and chat completions).
pub trait StreamingEngine: Send + Sync {
    /// Serve a completions request as a stream of annotated responses.
    async fn handle_completion(
        &self,
        req: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error>;

    /// Serve a chat-completions request as a stream of annotated chunks.
    async fn handle_chat(
        &self,
        req: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error>;
}
151
#[async_trait]
/// Object-safe engine interface for OpenAI-style embedding requests.
pub trait EmbeddingEngine: Send + Sync {
    /// Serve an embedding request as a stream of annotated responses.
    async fn handle_embedding(
        &self,
        req: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error>;
}
160
161pub fn make_engine_full() -> Arc<dyn StreamingEngine> {
162 let engine = EchoEngineFull {};
163 let data = EngineDispatcher::new(engine);
164 Arc::new(data)
165}
166
167#[async_trait]
168impl
169 AsyncEngine<
170 SingleIn<NvCreateChatCompletionRequest>,
171 ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
172 Error,
173 > for EchoEngineFull
174{
175 async fn generate(
176 &self,
177 incoming_request: SingleIn<NvCreateChatCompletionRequest>,
178 ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
179 let (request, context) = incoming_request.transfer(());
180 let deltas = request.response_generator();
181 let ctx = context.context();
182 let req = request.inner.messages.into_iter().next_back().unwrap();
183
184 let prompt = match req {
185 async_openai::types::ChatCompletionRequestMessage::User(user_msg) => {
186 match user_msg.content {
187 async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt) => {
188 prompt
189 }
190 _ => anyhow::bail!("Invalid request content field, expected Content::Text"),
191 }
192 }
193 _ => anyhow::bail!("Invalid request type, expected User message"),
194 };
195
196 let output = stream! {
197 let mut id = 1;
198 for c in prompt.chars() {
199 tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
201 let inner = deltas.create_choice(0, Some(c.to_string()), None, None);
202 let response = NvCreateChatCompletionStreamResponse {
203 inner,
204 };
205 yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
206 id += 1;
207 }
208
209 let inner = deltas.create_choice(0, None, Some(async_openai::types::FinishReason::Stop), None);
210 let response = NvCreateChatCompletionStreamResponse {
211 inner,
212 };
213 yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
214 };
215
216 Ok(ResponseStream::new(Box::pin(output), ctx))
217 }
218}
219
220#[async_trait]
221impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
222 for EchoEngineFull
223{
224 async fn generate(
225 &self,
226 incoming_request: SingleIn<NvCreateCompletionRequest>,
227 ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
228 let (request, context) = incoming_request.transfer(());
229 let deltas = request.response_generator();
230 let ctx = context.context();
231 let chars_string = prompt_to_string(&request.inner.prompt);
232 let output = stream! {
233 let mut id = 1;
234 for c in chars_string.chars() {
235 tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
236 let response = deltas.create_choice(0, Some(c.to_string()), None);
237 yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
238 id += 1;
239 }
240 let response = deltas.create_choice(0, None, Some("stop".to_string()));
241 yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
242
243 };
244
245 Ok(ResponseStream::new(Box::pin(output), ctx))
246 }
247}
248
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for EchoEngineFull
{
    /// Embeddings are not supported by the echo engine.
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    async fn generate(
        &self,
        _incoming_request: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        unimplemented!()
    }
}
264
#[async_trait]
/// Blanket dispatch: any engine that implements the completion, chat, and
/// embedding `AsyncEngine` interfaces can be used as a `StreamingEngine`.
impl<E> StreamingEngine for EngineDispatcher<E>
where
    E: AsyncEngine<
            SingleIn<NvCreateCompletionRequest>,
            ManyOut<Annotated<CompletionResponse>>,
            Error,
        > + AsyncEngine<
            SingleIn<NvCreateChatCompletionRequest>,
            ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
            Error,
        > + AsyncEngine<
            SingleIn<NvCreateEmbeddingRequest>,
            ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            Error,
        > + Send
        + Sync,
{
    // Delegates directly to the wrapped engine's completion impl.
    async fn handle_completion(
        &self,
        req: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        self.inner.generate(req).await
    }

    // Delegates directly to the wrapped engine's chat impl.
    async fn handle_chat(
        &self,
        req: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        self.inner.generate(req).await
    }
}
297
#[async_trait]
/// Blanket dispatch: any engine with the embedding `AsyncEngine` interface
/// can be used as an `EmbeddingEngine`.
impl<E> EmbeddingEngine for EngineDispatcher<E>
where
    E: AsyncEngine<
            SingleIn<NvCreateEmbeddingRequest>,
            ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            Error,
        > + Send
        + Sync,
{
    // Delegates directly to the wrapped engine's embedding impl.
    async fn handle_embedding(
        &self,
        req: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        self.inner.generate(req).await
    }
}
315
316pub struct EmbeddingEngineAdapter(Arc<dyn EmbeddingEngine>);
317
318impl EmbeddingEngineAdapter {
319 pub fn new(engine: Arc<dyn EmbeddingEngine>) -> Self {
320 EmbeddingEngineAdapter(engine)
321 }
322}
323
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for EmbeddingEngineAdapter
{
    /// Forwards the embedding request to the wrapped `dyn EmbeddingEngine`.
    async fn generate(
        &self,
        req: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        self.0.handle_embedding(req).await
    }
}
339
340pub struct StreamingEngineAdapter(Arc<dyn StreamingEngine>);
341
342impl StreamingEngineAdapter {
343 pub fn new(engine: Arc<dyn StreamingEngine>) -> Self {
344 StreamingEngineAdapter(engine)
345 }
346}
347
#[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
    for StreamingEngineAdapter
{
    /// Forwards the completion request to the wrapped `dyn StreamingEngine`.
    async fn generate(
        &self,
        req: SingleIn<NvCreateCompletionRequest>,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        self.0.handle_completion(req).await
    }
}
359
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for StreamingEngineAdapter
{
    /// Forwards the chat request to the wrapped `dyn StreamingEngine`.
    async fn generate(
        &self,
        req: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        self.0.handle_chat(req).await
    }
}