//! tokn-convert 0.2.0-rc.3
//!
//! Request and response conversion pipeline across tokn endpoint formats.
//!
//! Documentation
use super::super::error::{ConvertError, Result};
use super::super::ir::{IrDelta, IrResponse};
use super::event::SseEvent;
use crate::provider::Endpoint;
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use serde_json::Value;
use std::collections::BTreeMap;

/// Per-stream bookkeeping gathered while accumulating SSE events.
///
/// Despite the name, `response_id` and `model` are recorded for every endpoint format;
/// `output_items` is only populated from Responses API events.
#[derive(Default)]
struct ResponsesState {
  /// Output items keyed by the event's `output_index`.
  output_items: BTreeMap<usize, ResponseOutputItem>,
  /// First response/message id reported by the stream.
  response_id: Option<String>,
  /// First model name reported by the stream.
  model: Option<String>,
}

/// Accumulated state for a single entry of a Responses API output list.
#[derive(Default)]
struct ResponseOutputItem {
  /// Item id, from the item payload or the event's `item_id`.
  id: Option<String>,
  /// Tool-call id (`call_id`), preferred over `id` when back-filling tool calls.
  call_id: Option<String>,
  /// The item's `type` field.
  item_type: Option<String>,
  /// Tool/function name, when the item is a call.
  name: Option<String>,
  /// The item's `status` field.
  status: Option<String>,
  /// Output text built from text deltas and `output_text.done` events.
  text: String,
  /// Tool-call arguments (or custom tool input) built from deltas and `*.done` events.
  arguments: String,
  /// Reasoning summary text, keyed by `summary_index`.
  reasoning_summary: BTreeMap<usize, String>,
  /// Reasoning content text, keyed by `content_index`.
  reasoning_content: BTreeMap<usize, String>,
}

/// Identity of the upstream response as reported by the SSE stream.
///
/// Each field stays `None` until the stream reports a value for it.
#[derive(Debug, Default, Clone)]
pub struct SseMetadata {
  /// Provider-assigned id of the response (or message).
  pub response_id: Option<String>,
  /// Name of the model reported by the stream.
  pub model: Option<String>,
}

/// Incrementally folds decoded SSE payloads from one upstream endpoint into an [`IrResponse`].
pub struct SseAccumulator {
  /// Wire format of the incoming events; selects how each payload is translated.
  endpoint: Endpoint,
  /// Response assembled from every delta pushed so far.
  response: IrResponse,
  /// Stream identity (id/model) plus per-output-item state used while translating.
  responses: ResponsesState,
}

impl SseAccumulator {
  pub fn new(endpoint: Endpoint) -> Self {
    Self {
      endpoint,
      response: IrResponse::default(),
      responses: ResponsesState::default(),
    }
  }

  pub fn push_value(&mut self, value: &Value) -> Vec<IrDelta> {
    let deltas = match self.endpoint {
      Endpoint::ChatCompletions => {
        self.observe_chat_chunk(value);
        crate::chat::delta_from_chat_chunk(value)
      }
      Endpoint::Responses => self.delta_from_responses_event(value),
      Endpoint::Messages => {
        self.observe_messages_event(value);
        crate::messages::delta_from_messages_event(value)
      }
    };
    for delta in deltas.iter().cloned() {
      self.response.push_delta(delta);
    }
    deltas
  }

  pub fn finish(self) -> IrResponse {
    self.response
  }

  pub fn metadata(&self) -> SseMetadata {
    SseMetadata {
      response_id: self.responses.response_id.clone(),
      model: self.responses.model.clone(),
    }
  }

  fn delta_from_responses_event(&mut self, value: &Value) -> Vec<IrDelta> {
    self.observe_responses_response(value);
    self.observe_responses_output_item(value);
    self.observe_responses_part(value);
    let mut deltas = crate::responses::delta_from_responses_event(value);
    self.observe_responses_deltas(value, &deltas);
    for delta in &mut deltas {
      if let IrDelta::ToolCall { index, id, name, .. } = delta {
        if let Some(item) = self.responses.output_items.get(index) {
          if id.is_none() {
            *id = item.call_id.clone().or_else(|| item.id.clone());
          }
          if name.is_none() {
            *name = item.name.clone();
          }
        }
      }
    }
    deltas
  }

  fn observe_chat_chunk(&mut self, value: &Value) {
    if self.responses.response_id.is_none() {
      self.responses.response_id = value.get("id").and_then(Value::as_str).map(str::to_string);
    }
    if self.responses.model.is_none() {
      self.responses.model = value.get("model").and_then(Value::as_str).map(str::to_string);
    }
  }

  fn observe_messages_event(&mut self, value: &Value) {
    if !matches!(value.get("type").and_then(Value::as_str), Some("message_start")) {
      return;
    }
    let Some(message) = value.get("message") else {
      return;
    };
    if self.responses.response_id.is_none() {
      self.responses.response_id = message.get("id").and_then(Value::as_str).map(str::to_string);
    }
    if self.responses.model.is_none() {
      self.responses.model = message.get("model").and_then(Value::as_str).map(str::to_string);
    }
  }

  fn observe_responses_response(&mut self, value: &Value) {
    let Some(response) = value.get("response") else {
      return;
    };
    if self.responses.response_id.is_none() {
      self.responses.response_id = response.get("id").and_then(Value::as_str).map(str::to_string);
    }
    if self.responses.model.is_none() {
      self.responses.model = response.get("model").and_then(Value::as_str).map(str::to_string);
    }
    if let Some(usage) = crate::ir::usage_from_openai(response) {
      self.response.usage = Some(usage);
    }
  }

  fn observe_responses_output_item(&mut self, value: &Value) {
    match value.get("type").and_then(Value::as_str) {
      Some("response.output_item.added") | Some("response.output_item.done") => {}
      _ => return,
    }
    let Some(index) = value.get("output_index").and_then(Value::as_u64).map(|v| v as usize) else {
      return;
    };
    let Some(item) = value.get("item") else {
      return;
    };
    let entry = self.responses.output_items.entry(index).or_default();
    if entry.item_type.is_none() {
      entry.item_type = item.get("type").and_then(Value::as_str).map(str::to_string);
    }
    if entry.status.is_none() {
      entry.status = item.get("status").and_then(Value::as_str).map(str::to_string);
    }
    if entry.id.is_none() {
      entry.id = item
        .get("id")
        .or_else(|| value.get("item_id"))
        .and_then(Value::as_str)
        .map(str::to_string);
    }
    if entry.call_id.is_none() {
      entry.call_id = item.get("call_id").and_then(Value::as_str).map(str::to_string);
    }
    if entry.name.is_none() {
      entry.name = item.get("name").and_then(Value::as_str).map(str::to_string);
    }
    if let Some(arguments) = item
      .get("arguments")
      .or_else(|| item.get("input"))
      .and_then(Value::as_str)
    {
      entry.arguments = arguments.to_string();
    }
  }

  fn observe_responses_part(&mut self, value: &Value) {
    let Some(index) = value.get("output_index").and_then(Value::as_u64).map(|v| v as usize) else {
      return;
    };
    let Some(entry) = self.responses.output_items.get_mut(&index) else {
      return;
    };
    match value.get("type").and_then(Value::as_str) {
      Some("response.output_text.done") => {
        if let Some(text) = value.get("text").and_then(Value::as_str) {
          entry.text = text.to_string();
        }
      }
      Some("response.reasoning_summary_text.done") => {
        if let (Some(summary_index), Some(text)) = (
          value.get("summary_index").and_then(Value::as_u64).map(|v| v as usize),
          value.get("text").and_then(Value::as_str),
        ) {
          entry.reasoning_summary.insert(summary_index, text.to_string());
        }
      }
      Some("response.function_call_arguments.done") | Some("response.custom_tool_call_input.done") => {
        if let Some(arguments) = value
          .get("arguments")
          .or_else(|| value.get("input"))
          .and_then(Value::as_str)
        {
          entry.arguments = arguments.to_string();
        }
      }
      _ => {}
    }
  }

  fn observe_responses_deltas(&mut self, value: &Value, deltas: &[IrDelta]) {
    let Some(index) = value.get("output_index").and_then(Value::as_u64).map(|v| v as usize) else {
      return;
    };
    let entry = self.responses.output_items.entry(index).or_default();
    for delta in deltas {
      match delta {
        IrDelta::Text(text) => entry.text.push_str(text),
        IrDelta::Reasoning(text) => {
          let target = match value.get("type").and_then(Value::as_str) {
            Some("response.reasoning_summary_text.delta") => value
              .get("summary_index")
              .and_then(Value::as_u64)
              .map(|v| v as usize)
              .map(|i| entry.reasoning_summary.entry(i).or_default()),
            Some("response.reasoning_text.delta") => value
              .get("content_index")
              .and_then(Value::as_u64)
              .map(|v| v as usize)
              .map(|i| entry.reasoning_content.entry(i).or_default()),
            _ => None,
          };
          if let Some(buf) = target {
            buf.push_str(text);
          }
        }
        IrDelta::ToolCall { arguments_delta, .. } => entry.arguments.push_str(arguments_delta),
        _ => {}
      }
    }
  }
}

/// Drains an SSE response body and folds every event into a single [`IrResponse`].
///
/// Reading stops at the end of the stream or at the first event whose `is_done()` is true.
///
/// # Errors
///
/// Returns a `ConvertError::sse` error if the event stream fails to decode, or if an event
/// before the terminator carries no JSON payload.
pub async fn accumulate(endpoint: Endpoint, resp: reqwest::Response) -> Result<IrResponse> {
  let mut events = resp.bytes_stream().eventsource();
  let mut accumulator = SseAccumulator::new(endpoint);
  loop {
    let Some(next) = events.next().await else {
      break;
    };
    let raw = next.map_err(|err| ConvertError::sse(err.to_string()))?;
    let event = SseEvent::from(raw);
    if event.is_done() {
      break;
    }
    let payload = event.json.as_ref().ok_or_else(|| ConvertError::sse("expected JSON SSE payload"))?;
    accumulator.push_value(payload);
  }
  Ok(accumulator.finish())
}