use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};

use anyhow::Context as _;
use dynamo_async_openai::types::{ChatCompletionMessageContent, FinishReason};
use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

use crate::entrypoint::EngineConfig;
use crate::entrypoint::input::common;
use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};

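/// Completion-length cap used when no request template overrides it.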
const MAX_TOKENS: u32 = 8192;
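/// Results are written next to the input file under this name.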
const OUTPUT_FILENAME: &str = "output.jsonl";
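/// One line of the input JSONL file, plus the fields filled in after the
/// prompt has run: response text, token counts, latency, and finish reason.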
#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
text: String,
response: Option<String>,
#[serde(default)]
tokens_in: usize,
#[serde(default)]
tokens_out: usize,
#[serde(default)]
elapsed_ms: usize,
#[serde(default, skip_serializing_if = "Option::is_none")]
finish_reason: Option<FinishReason>,
#[serde(skip, default)]
request_id: usize,
}
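/// Read prompts from `input_jsonl` (one JSON object per line with a `text`
/// field), run them all concurrently against the prepared engine, and write
/// completed entries to `output.jsonl` beside the input file.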
pub async fn run(
distributed_runtime: DistributedRuntime,
input_jsonl: PathBuf,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let cancel_token = distributed_runtime.primary_token();
    // `is_file()` is also false when the path does not exist.
    if !input_jsonl.is_file() {
anyhow::bail!(
"Missing or not a file: {}. Should be a JSON Lines file.",
input_jsonl.display()
);
}
let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?;
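    // The preprocessor is used here only for token counting; engines
    // without a tokenizer skip those counts.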
let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new(
prepared_engine.card.take().unwrap(),
)?)
} else {
None
};
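    // Completed entries are handed to a dedicated writer task over this
    // channel, so file I/O never blocks the request loop.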
let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
let dw_cancel_token = cancel_token.clone();
let mut output_file = input_jsonl.clone();
output_file.set_file_name(OUTPUT_FILENAME);
tokio::spawn(async move {
if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
tracing::error!(%err, "Failed writing output to {}", output_file.display());
}
});
let service_name_ref = Arc::new(prepared_engine.service_name);
let tokens_in = Arc::new(AtomicU64::new(0));
let tokens_out = Arc::new(AtomicU64::new(0));
let mut handles = vec![];
let mut num_entries = 0;
let input_file = tokio::fs::File::open(&input_jsonl)
.await
.with_context(|| input_jsonl.display().to_string())?;
let buffered_input = tokio::io::BufReader::new(input_file);
tracing::info!("Timer start.");
let start = Instant::now();
let mut lines = buffered_input.lines();
let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
    while let Some(line) = lines.next_line().await? {
if cancel_token.is_cancelled() {
break;
}
if line.is_empty() {
continue;
}
let request_id = num_entries;
num_entries += 1;
let mut entry: Entry = match serde_json::from_str(&line) {
Ok(entry) => entry,
Err(err) => {
anyhow::bail!("Error parsing entry: '{line}'. {err}");
}
};
entry.request_id = request_id;
let engine = prepared_engine.engine.clone();
let pre_processor = pre_processor.clone();
let tokens_in = tokens_in.clone();
let tokens_out = tokens_out.clone();
let done_entries_tx = done_entries_tx.clone();
let service_name_ref = service_name_ref.clone();
let template_clone = template.clone();
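        // Spawn one task per prompt so every request is in flight
        // concurrently; completed entries reach the writer in completion
        // order, not input order.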
let handle = tokio::spawn(async move {
let local_start = Instant::now();
let response = match evaluate(
request_id,
service_name_ref.as_str(),
engine,
&mut entry,
template_clone,
)
.await
{
Ok(r) => r,
Err(err) => {
tracing::error!(%err, entry.text, "Failed evaluating prompt");
return;
}
};
            let local_elapsed = local_start.elapsed();
entry.elapsed_ms = local_elapsed.as_millis() as usize;
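            // When a tokenizer is available, count prompt and response tokens
            // so the run summary can report throughput.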
if let Some(pre) = pre_processor {
entry.tokens_in = match pre.tokenize(&entry.text) {
Ok(encoding) => encoding.token_ids().len(),
Err(err) => {
tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
0
}
};
entry.tokens_out = match pre.tokenize(&response) {
Ok(encoding) => encoding.token_ids().len(),
Err(err) => {
tracing::warn!(%err, response, "Failed tokenizing response");
0
}
};
tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
}
entry.response = Some(response);
let _ = done_entries_tx.send(entry).await;
});
handles.push(handle);
}
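    // Wait for every in-flight request, or bail out early on cancellation
    // (in that case the tasks are detached, not aborted).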
tokio::select! {
_ = cancel_token.cancelled() => {
return Ok(());
}
_ = futures::future::join_all(handles) => {
}
}
    let elapsed = start.elapsed();
let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
    // Give the output writer a moment to drain queued entries before the
    // summary log and cancellation below.
    tokio::time::sleep(Duration::from_millis(1)).await;
    tracing::info!(
        "Ran {} prompts in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
num_entries,
humantime::format_duration(elapsed_clean),
tokens_in,
tokens_in / cmp::max(elapsed.as_secs(), 1),
tokens_out,
tokens_out / cmp::max(elapsed.as_secs(), 1),
);
cancel_token.cancel();
Ok(())
}
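/// Send a single prompt to the engine as a streaming chat-completion request
/// and concatenate the text deltas into one response string.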
async fn evaluate(
request_id: usize,
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine,
entry: &mut Entry,
template: Option<Arc<RequestTemplate>>,
) -> anyhow::Result<String> {
let user_message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
entry.text.clone(),
),
name: None,
},
);
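    // A request template, if provided, overrides the default model name,
    // completion cap, and temperature.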
let inner = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(vec![user_message])
.model(
template
.as_ref()
.map_or_else(|| service_name.to_string(), |t| t.model.clone()),
)
.stream(true)
.max_completion_tokens(
template
.as_ref()
.map_or(MAX_TOKENS, |t| t.max_completion_tokens),
)
.temperature(template.as_ref().map_or(0.7, |t| t.temperature))
.build()?;
let req = NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new();
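    // Accumulate streamed deltas until the engine reports a finish reason
    // or the stream ends.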
while let Some(item) = stream.next().await {
match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => {
                // A streamed chunk is expected to carry at least one choice;
                // skip any that do not.
                let Some(chat_comp) = data.choices.first() else {
                    continue;
                };
if let Some(c) = &chat_comp.delta.content {
match c {
ChatCompletionMessageContent::Text(text) => {
output += text;
}
                        ChatCompletionMessageContent::Parts(_) => {
                            // Multi-part content (e.g. image parts) is not
                            // expected for these plain-text prompts; ignore it.
                        }
}
}
entry.finish_reason = chat_comp.finish_reason;
if chat_comp.finish_reason.is_some() {
tracing::trace!(
request_id,
"finish reason: {:?}",
chat_comp.finish_reason.unwrap()
);
break;
}
}
(None, Some("error")) => {
                tracing::error!(request_id, "Engine stream returned an error event");
for err in item.comment.unwrap_or_default() {
tracing::error!(request_id, "Engine error: {err}");
}
}
(None, Some(annotation)) => {
tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
}
_ => {
unreachable!("Event from engine with no data, no error, no annotation.");
}
}
}
Ok(output)
}
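/// Receive completed entries and append each as one JSON line to
/// `output_file`, until the channel closes or the run is cancelled.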
async fn output_writer(
cancel_token: CancellationToken,
mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
output_file: &Path,
) -> anyhow::Result<()> {
let mut num_completed = 0;
let mut f = tokio::fs::File::create(output_file).await?;
loop {
let entry = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_entry = entries_rx.recv() => {
match maybe_entry {
Some(entry) => entry,
None => {break;}
}
}
};
let mut s = serde_json::to_string(&entry)?;
s.push('\n');
f.write_all(s.as_bytes()).await?;
num_completed += 1;
tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
}
Ok(())
}