1use std::env;
5use std::sync::Arc;
6use std::sync::LazyLock;
7use std::time::Duration;
8
9use async_stream::stream;
10use async_trait::async_trait;
11
12use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
13use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
14use dynamo_runtime::protocols::annotated::Annotated;
15
16use crate::protocols::openai::{
17 chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
18 completions::{NvCreateCompletionRequest, NvCreateCompletionResponse, prompt_to_string},
19};
20use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
21use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
22
/// Settings describing how this process participates in a multi-node deployment.
///
/// The [`Default`] value is a one-node setup: `num_nodes == 1`, `node_rank == 0`
/// and an empty `leader_addr`.
#[derive(Clone, Debug)]
pub struct MultiNodeConfig {
    /// Number of nodes taking part in the deployment.
    pub num_nodes: u32,
    /// Rank of this node.
    /// NOTE(review): presumably zero-based with rank 0 acting as leader — confirm.
    pub node_rank: u32,
    /// Address of the leader node; empty in the default configuration.
    /// NOTE(review): expected format (host:port?) is not visible here — confirm.
    pub leader_addr: String,
}

impl Default for MultiNodeConfig {
    /// Returns the one-node configuration with an empty leader address.
    fn default() -> Self {
        Self {
            leader_addr: String::new(),
            node_rank: 0,
            num_nodes: 1,
        }
    }
}
46
/// Pause applied by the echo engine before each echoed character.
///
/// Computed once, on first access, from the `DYN_TOKEN_ECHO_DELAY_MS` environment
/// variable (interpreted as whole milliseconds). When the variable is unset, not valid
/// Unicode, or not parseable as a `u64`, the delay falls back to 10 ms.
pub static TOKEN_ECHO_DELAY: LazyLock<Duration> = LazyLock::new(|| {
    const DEFAULT_DELAY_MS: u64 = 10;

    // Both "variable missing" and "variable malformed" deliberately use the default.
    let millis = match env::var("DYN_TOKEN_ECHO_DELAY_MS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_DELAY_MS),
        Err(_) => DEFAULT_DELAY_MS,
    };

    Duration::from_millis(millis)
});
64
/// Stateless engine that echoes request text back to the caller one character at a
/// time, pausing [`TOKEN_ECHO_DELAY`] before each character.
struct EchoEngine;
68
/// Engine wrapper that validates a request before passing it to the wrapped engine.
pub struct ValidateEngine<E> {
    /// The wrapped engine that receives requests once they pass validation.
    inner: E,
}

impl<E> ValidateEngine<E> {
    /// Wraps `inner` so that requests are validated before reaching it.
    pub fn new(inner: E) -> Self {
        ValidateEngine { inner }
    }
}
79
/// Wrapper that exposes a generic engine through the [`StreamingEngine`] and
/// [`EmbeddingEngine`] traits by forwarding each call to the wrapped engine.
pub struct EngineDispatcher<E> {
    /// The engine every request is forwarded to.
    inner: E,
}

impl<E> EngineDispatcher<E> {
    /// Wraps `inner` in a dispatcher.
    pub fn new(inner: E) -> Self {
        Self { inner }
    }
}
91
92pub trait ValidateRequest {
94 fn validate(&self) -> Result<(), anyhow::Error>;
95}
96
97#[async_trait]
99pub trait StreamingEngine: Send + Sync {
100 async fn handle_completion(
101 &self,
102 req: SingleIn<NvCreateCompletionRequest>,
103 ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error>;
104
105 async fn handle_chat(
106 &self,
107 req: SingleIn<NvCreateChatCompletionRequest>,
108 ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error>;
109}
110
111#[async_trait]
113pub trait EmbeddingEngine: Send + Sync {
114 async fn handle_embedding(
115 &self,
116 req: SingleIn<NvCreateEmbeddingRequest>,
117 ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error>;
118}
119
120pub fn make_echo_engine() -> Arc<dyn StreamingEngine> {
121 let engine = EchoEngine {};
122 let data = EngineDispatcher::new(engine);
123 Arc::new(data)
124}
125
126#[async_trait]
127impl
128 AsyncEngine<
129 SingleIn<NvCreateChatCompletionRequest>,
130 ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
131 Error,
132 > for EchoEngine
133{
134 async fn generate(
135 &self,
136 incoming_request: SingleIn<NvCreateChatCompletionRequest>,
137 ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
138 let (request, context) = incoming_request.transfer(());
139 let ctx = context.context();
140 let mut deltas = request.response_generator(ctx.id().to_string());
141 let Some(req) = request.inner.messages.into_iter().next_back() else {
142 anyhow::bail!("Empty chat messages in request");
143 };
144
145 let prompt = match req {
146 dynamo_async_openai::types::ChatCompletionRequestMessage::User(user_msg) => {
147 match user_msg.content {
148 dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
149 prompt,
150 ) => prompt,
151 _ => anyhow::bail!("Invalid request content field, expected Content::Text"),
152 }
153 }
154 _ => anyhow::bail!("Invalid request type, expected User message"),
155 };
156
157 let output = stream! {
158 let mut id = 1;
159 for c in prompt.chars() {
160 tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
162 let response = deltas.create_choice(0, Some(c.to_string()), None, None, None);
163 yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
164 id += 1;
165 }
166
167 let response = deltas.create_choice(0, None, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
168 yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
169 };
170
171 Ok(ResponseStream::new(Box::pin(output), ctx))
172 }
173}
174
175#[async_trait]
176impl
177 AsyncEngine<
178 SingleIn<NvCreateCompletionRequest>,
179 ManyOut<Annotated<NvCreateCompletionResponse>>,
180 Error,
181 > for EchoEngine
182{
183 async fn generate(
184 &self,
185 incoming_request: SingleIn<NvCreateCompletionRequest>,
186 ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
187 let (request, context) = incoming_request.transfer(());
188 let ctx = context.context();
189 let deltas = request.response_generator(ctx.id().to_string());
190 let chars_string = prompt_to_string(&request.inner.prompt);
191 let output = stream! {
192 let mut id = 1;
193 for c in chars_string.chars() {
194 tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
195 let response = deltas.create_choice(0, Some(c.to_string()), None, None);
196 yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
197 id += 1;
198 }
199 let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::CompletionFinishReason::Stop), None);
200 yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
201
202 };
203
204 Ok(ResponseStream::new(Box::pin(output), ctx))
205 }
206}
207
208#[async_trait]
209impl
210 AsyncEngine<
211 SingleIn<NvCreateEmbeddingRequest>,
212 ManyOut<Annotated<NvCreateEmbeddingResponse>>,
213 Error,
214 > for EchoEngine
215{
216 async fn generate(
217 &self,
218 _incoming_request: SingleIn<NvCreateEmbeddingRequest>,
219 ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
220 unimplemented!()
221 }
222}
223
224#[async_trait]
225impl<E, Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for ValidateEngine<E>
226where
227 E: AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> + Send + Sync,
228 Req: ValidateRequest + Send + Sync + 'static,
229 Resp: Send + Sync + 'static,
230{
231 async fn generate(
232 &self,
233 incoming_request: SingleIn<Req>,
234 ) -> Result<ManyOut<Annotated<Resp>>, Error> {
235 let (request, context) = incoming_request.into_parts();
236
237 if let Err(validation_error) = request.validate() {
239 return Err(anyhow::anyhow!("Validation failed: {}", validation_error));
240 }
241
242 let validated_request = SingleIn::rejoin(request, context);
244 self.inner.generate(validated_request).await
245 }
246}
247
248#[async_trait]
249impl<E> StreamingEngine for EngineDispatcher<E>
250where
251 E: AsyncEngine<
252 SingleIn<NvCreateCompletionRequest>,
253 ManyOut<Annotated<NvCreateCompletionResponse>>,
254 Error,
255 > + AsyncEngine<
256 SingleIn<NvCreateChatCompletionRequest>,
257 ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
258 Error,
259 > + AsyncEngine<
260 SingleIn<NvCreateEmbeddingRequest>,
261 ManyOut<Annotated<NvCreateEmbeddingResponse>>,
262 Error,
263 > + Send
264 + Sync,
265{
266 async fn handle_completion(
267 &self,
268 req: SingleIn<NvCreateCompletionRequest>,
269 ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
270 self.inner.generate(req).await
271 }
272
273 async fn handle_chat(
274 &self,
275 req: SingleIn<NvCreateChatCompletionRequest>,
276 ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
277 self.inner.generate(req).await
278 }
279}
280
281#[async_trait]
282impl<E> EmbeddingEngine for EngineDispatcher<E>
283where
284 E: AsyncEngine<
285 SingleIn<NvCreateEmbeddingRequest>,
286 ManyOut<Annotated<NvCreateEmbeddingResponse>>,
287 Error,
288 > + Send
289 + Sync,
290{
291 async fn handle_embedding(
292 &self,
293 req: SingleIn<NvCreateEmbeddingRequest>,
294 ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
295 self.inner.generate(req).await
296 }
297}
298
299pub struct EmbeddingEngineAdapter(Arc<dyn EmbeddingEngine>);
300
301impl EmbeddingEngineAdapter {
302 pub fn new(engine: Arc<dyn EmbeddingEngine>) -> Self {
303 EmbeddingEngineAdapter(engine)
304 }
305}
306
307#[async_trait]
308impl
309 AsyncEngine<
310 SingleIn<NvCreateEmbeddingRequest>,
311 ManyOut<Annotated<NvCreateEmbeddingResponse>>,
312 Error,
313 > for EmbeddingEngineAdapter
314{
315 async fn generate(
316 &self,
317 req: SingleIn<NvCreateEmbeddingRequest>,
318 ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
319 self.0.handle_embedding(req).await
320 }
321}
322
323pub struct StreamingEngineAdapter(Arc<dyn StreamingEngine>);
324
325impl StreamingEngineAdapter {
326 pub fn new(engine: Arc<dyn StreamingEngine>) -> Self {
327 StreamingEngineAdapter(engine)
328 }
329}
330
331#[async_trait]
332impl
333 AsyncEngine<
334 SingleIn<NvCreateCompletionRequest>,
335 ManyOut<Annotated<NvCreateCompletionResponse>>,
336 Error,
337 > for StreamingEngineAdapter
338{
339 async fn generate(
340 &self,
341 req: SingleIn<NvCreateCompletionRequest>,
342 ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
343 self.0.handle_completion(req).await
344 }
345}
346
347#[async_trait]
348impl
349 AsyncEngine<
350 SingleIn<NvCreateChatCompletionRequest>,
351 ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
352 Error,
353 > for StreamingEngineAdapter
354{
355 async fn generate(
356 &self,
357 req: SingleIn<NvCreateChatCompletionRequest>,
358 ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
359 self.0.handle_chat(req).await
360 }
361}