use std::{
borrow::Cow,
collections::HashMap,
fmt::{Debug, Display},
};
use dbn::{Compression, SType, Schema};
use hex::ToHex;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use tracing::{debug, error, instrument};
use crate::{ApiKey, Error, USER_AGENT};
use super::{SlowReaderBehavior, Subscription};
/// Returns the live gateway address (`host:port`) for `dataset`.
///
/// The subdomain is the dataset code with every `.` turned into `-` and all
/// ASCII letters lowercased, e.g. `GLBX.MDP3` maps to
/// `glbx-mdp3.lsg.databento.com:13000`.
pub fn determine_gateway(dataset: &str) -> String {
    const DEFAULT_PORT: u16 = 13_000;
    let subdomain: String = dataset
        .chars()
        .map(|c| if c == '.' { '-' } else { c.to_ascii_lowercase() })
        .collect();
    format!("{subdomain}.lsg.databento.com:{DEFAULT_PORT}")
}
/// Sends live gateway protocol messages through the writer `sender`.
///
/// This type does not own a reader: methods that need gateway replies, such
/// as `authenticate`, take the reader as a separate argument.
pub struct Protocol<W> {
/// The writer every outgoing request is written to.
sender: W,
}
impl<W> Protocol<W>
where
    W: AsyncWriteExt + Unpin,
{
    /// Creates a protocol handler that writes its requests to `sender`.
    pub fn new(sender: W) -> Self {
        Self { sender }
    }

    /// Performs the CRAM authentication handshake with the gateway.
    ///
    /// Reads a greeting line, then a `cram=<challenge>` line from `recver`,
    /// writes an [`AuthRequest`] built from `key`, `dataset`, the challenge and
    /// `options`, and finally reads and parses the gateway's auth response.
    ///
    /// Returns the session ID from the response, or an empty string if the
    /// response does not include one.
    ///
    /// # Errors
    /// Returns an error if reading or writing fails, if the second line is not
    /// a CRAM challenge, or if the auth response does not contain `success=1`.
    #[instrument(skip(self, recver, key, options))]
    pub async fn authenticate<R>(
        &mut self,
        recver: &mut R,
        key: &ApiKey,
        dataset: &str,
        options: SessionOptions<'_>,
    ) -> crate::Result<String>
    where
        R: AsyncBufReadExt + Unpin,
    {
        let mut greeting = String::new();
        recver.read_line(&mut greeting).await?;
        strip_line_ending(&mut greeting);
        debug!(greeting);

        let mut response = String::new();
        recver.read_line(&mut response).await?;
        strip_line_ending(&mut response);
        let challenge = Challenge::parse(&response).inspect_err(|_| {
            error!(?response, "No CRAM challenge in response from gateway");
        })?;
        debug!(%challenge, "Received CRAM challenge");

        let auth_req = AuthRequest::new(key, dataset, &challenge, options);
        debug!(?auth_req, "Sending CRAM reply");
        self.sender.write_all(auth_req.as_bytes()).await?;

        response.clear();
        recver.read_line(&mut response).await?;
        if response.is_empty() {
            // Nothing at all was read before the end of the stream.
            error!("Received empty auth response");
        } else {
            // Strip only a real '\n': slicing off the last byte could split a
            // multi-byte character and panic when the line has no terminator.
            strip_line_ending(&mut response);
            debug!(auth_resp = response.as_str(), "Received auth response");
        }
        let auth_resp = AuthResponse::parse(&response)?;
        Ok(auth_resp
            .session_id()
            .map(ToOwned::to_owned)
            .unwrap_or_default())
    }

    /// Sends `sub` to the gateway as one or more subscription requests.
    ///
    /// The symbols are split by `to_chunked_api_string`; one [`SubRequest`] is
    /// written per chunk, and only the final one is flagged `is_last`.
    ///
    /// # Errors
    /// Returns [`Error::BadArgument`] if a snapshot is requested together with a
    /// start time, or if the subscription yields no symbol chunks. Returns an
    /// error if writing to the sender fails.
    pub async fn subscribe(&mut self, sub: &Subscription) -> crate::Result<()> {
        let Subscription {
            schema,
            stype_in,
            start,
            use_snapshot,
            ..
        } = sub;
        if *use_snapshot && start.is_some() {
            return Err(Error::BadArgument {
                param_name: "use_snapshot",
                desc: "cannot request snapshot with start time".to_owned(),
            });
        }
        let start_nanos = start.as_ref().map(|start| start.unix_timestamp_nanos());
        let symbol_chunks = sub.symbols.to_chunked_api_string();
        // `len() - 1` would underflow for an empty chunk list (panicking in
        // debug builds, silently sending nothing in release), so reject it.
        let Some(last_chunk_idx) = symbol_chunks.len().checked_sub(1) else {
            return Err(Error::BadArgument {
                param_name: "symbols",
                desc: "must contain at least one symbol".to_owned(),
            });
        };
        for (i, sym_str) in symbol_chunks.into_iter().enumerate() {
            let sub_req = SubRequest::new(
                *schema,
                *stype_in,
                start_nanos,
                *use_snapshot,
                sub.id,
                &sym_str,
                i == last_chunk_idx,
            );
            debug!(?sub_req, "Sending subscription request");
            self.sender.write_all(sub_req.as_bytes()).await?;
        }
        Ok(())
    }

    /// Writes the `start_session` request to the gateway.
    ///
    /// # Errors
    /// Returns an error if writing to the sender fails.
    pub async fn start_session(&mut self) -> crate::Result<()> {
        Ok(self.sender.write_all(StartRequest.as_bytes()).await?)
    }

    /// Shuts down the sender via `AsyncWriteExt::shutdown`.
    ///
    /// # Errors
    /// Returns an error if the shutdown fails.
    pub async fn shutdown(&mut self) -> crate::Result<()> {
        Ok(self.sender.shutdown().await?)
    }

    /// Consumes the protocol handler and returns the underlying writer.
    pub fn into_inner(self) -> W {
        self.sender
    }
}

/// Removes a single trailing `'\n'` from `line`, if one is present.
///
/// `read_line` keeps the delimiter, but a line read at the end of the stream
/// may lack it, so the last character must not be popped unconditionally.
fn strip_line_ending(line: &mut String) {
    if line.ends_with('\n') {
        line.pop();
    }
}
#[derive(Debug, Clone)]
pub struct Challenge<'a>(&'a str);
impl<'a> Challenge<'a> {
pub fn parse(response: &'a str) -> crate::Result<Self> {
if let Some(challenge) = response.strip_prefix("cram=") {
Ok(Self(challenge))
} else {
Err(Error::internal(
"no CRAM challenge in response from gateway",
))
}
}
}
impl Display for Challenge<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Per-session settings that are encoded into the authentication request.
#[derive(Clone, Debug)]
pub struct SessionOptions<'a> {
    /// Compression to request for the session.
    pub compression: Compression,
    /// Whether to request `ts_out` (sent as `ts_out=1` or `ts_out=0`).
    pub send_ts_out: bool,
    /// Heartbeat interval in seconds; not sent when `None`.
    pub heartbeat_interval_s: Option<i64>,
    /// Extra text appended, after a space, to the client user agent.
    pub user_agent_ext: Option<&'a str>,
    /// Slow reader behavior to request; not sent when `None`.
    pub slow_reader_behavior: Option<SlowReaderBehavior>,
}

impl Default for SessionOptions<'_> {
    /// No compression, no `ts_out`, and every optional setting unset.
    fn default() -> Self {
        Self {
            compression: Compression::None,
            send_ts_out: bool::default(),
            heartbeat_interval_s: Option::default(),
            user_agent_ext: Option::default(),
            slow_reader_behavior: Option::default(),
        }
    }
}
/// Lazily splits a `|`-delimited gateway message into `(key, value)` pairs.
///
/// Each segment is split at its first `=`; segments that contain no `=` are
/// skipped.
fn parse_kv_pairs(s: &str) -> impl Iterator<Item = (&str, &str)> {
    s.split('|').flat_map(|segment| segment.split_once('='))
}
/// A protocol message that is written verbatim to the gateway.
pub trait RawApiMsg {
    /// The complete message text. Every implementation in this module ends it
    /// with `'\n'`.
    fn as_str(&self) -> &str;

    /// The message text as bytes, ready to be written to the sender.
    fn as_bytes(&self) -> &[u8] {
        let text = self.as_str();
        text.as_bytes()
    }
}
/// The authentication request answering the gateway's CRAM challenge.
///
/// The wrapped string is built only by [`AuthRequest::new`] and always ends
/// with `'\n'`.
#[derive(Clone)]
pub struct AuthRequest(String);

impl AuthRequest {
    /// Builds the newline-terminated authentication request.
    ///
    /// The `auth` field is the hex-encoded SHA-256 of `"{challenge}|{key}"`,
    /// followed by `-` and the key's bucket ID. `heartbeat_interval_s` and
    /// `slow_reader_behavior` are appended only when set in `options`.
    pub fn new(
        key: &ApiKey,
        dataset: &str,
        challenge: &Challenge,
        options: SessionOptions,
    ) -> Self {
        let challenge_key = format!("{challenge}|{}", key.0);
        let digest = Sha256::digest(challenge_key.as_bytes());
        let encoded_response = digest.encode_hex::<String>();
        let bucket_id = key.bucket_id();
        let send_ts_out = u8::from(options.send_ts_out);
        let user_agent: Cow<'_, str> = if let Some(ext) = options.user_agent_ext {
            Cow::Owned(format!("{} {ext}", *USER_AGENT))
        } else {
            Cow::Borrowed(&USER_AGENT)
        };

        let mut fields = vec![
            format!("auth={encoded_response}-{bucket_id}"),
            format!("dataset={dataset}"),
            String::from("encoding=dbn"),
            format!("compression={}", options.compression),
            format!("ts_out={send_ts_out}"),
            format!("client={user_agent}"),
        ];
        if let Some(interval) = options.heartbeat_interval_s {
            fields.push(format!("heartbeat_interval_s={interval}"));
        }
        if let Some(behavior) = options.slow_reader_behavior {
            fields.push(format!("slow_reader_behavior={behavior}"));
        }

        let mut req = fields.join("|");
        req.push('\n');
        Self(req)
    }
}

impl RawApiMsg for AuthRequest {
    fn as_str(&self) -> &str {
        &self.0
    }
}

impl Debug for AuthRequest {
    /// Shows the request text without its trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let without_newline = &self.0[..self.0.len() - 1];
        f.write_str(without_newline)
    }
}
pub struct AuthResponse<'a>(HashMap<&'a str, &'a str>);
impl<'a> AuthResponse<'a> {
pub fn parse(response: &'a str) -> crate::Result<Self> {
let auth_keys: HashMap<&'a str, &'a str> = parse_kv_pairs(response).collect();
if auth_keys.get("success").map(|v| *v != "1").unwrap_or(true) {
return Err(Error::Auth(
auth_keys
.get("error")
.map(|msg| (*msg).to_owned())
.unwrap_or_else(|| response.to_owned()),
));
}
Ok(Self(auth_keys))
}
pub fn session_id(&self) -> Option<&str> {
self.0.get("session_id").copied()
}
pub fn get_ref(&self) -> &HashMap<&'a str, &'a str> {
&self.0
}
}
/// A single subscription request message.
///
/// The wrapped string is built only by [`SubRequest::new`] and always ends
/// with `'\n'`.
#[derive(Clone)]
pub struct SubRequest(String);

impl SubRequest {
    /// Builds a newline-terminated subscription request for one symbol chunk.
    ///
    /// The boolean flags are sent as `0`/`1`. `start` and `id` are appended
    /// only when provided.
    pub fn new(
        schema: Schema,
        stype_in: SType,
        start_nanos: Option<i128>,
        use_snapshot: bool,
        id: Option<u32>,
        symbols: &str,
        is_last: bool,
    ) -> Self {
        let mut fields = vec![
            format!("schema={schema}"),
            format!("stype_in={stype_in}"),
            format!("symbols={symbols}"),
            format!("snapshot={}", u8::from(use_snapshot)),
            format!("is_last={}", u8::from(is_last)),
        ];
        if let Some(start) = start_nanos {
            fields.push(format!("start={start}"));
        }
        if let Some(sub_id) = id {
            fields.push(format!("id={sub_id}"));
        }

        let mut args = fields.join("|");
        args.push('\n');
        Self(args)
    }
}

impl RawApiMsg for SubRequest {
    fn as_str(&self) -> &str {
        &self.0
    }
}

impl Debug for SubRequest {
    /// Shows the request text without its trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let without_newline = &self.0[..self.0.len() - 1];
        f.write_str(without_newline)
    }
}
/// The fixed `start_session` request message.
#[derive(Debug, Clone, Copy)]
pub struct StartRequest;

impl RawApiMsg for StartRequest {
    fn as_str(&self) -> &str {
        const START_SESSION: &str = "start_session\n";
        START_SESSION
    }
}