use std::{
collections::HashMap,
fmt::{Debug, Display},
};
use dbn::{SType, Schema};
use hex::ToHex;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use tracing::{debug, error, instrument};
use crate::{ApiKey, Error};
use super::Subscription;
/// Builds the `host:port` address of the live gateway for `dataset`.
///
/// Each `.` in the dataset code becomes `-`, and ASCII letters are lowercased to form the
/// subdomain. For example, `GLBX.MDP3` maps to `glbx-mdp3.lsg.databento.com:13000`.
pub fn determine_gateway(dataset: &str) -> String {
    const DEFAULT_PORT: u16 = 13_000;
    let subdomain: String = dataset
        .chars()
        .map(|c| if c == '.' { '-' } else { c.to_ascii_lowercase() })
        .collect();
    format!("{subdomain}.lsg.databento.com:{DEFAULT_PORT}")
}
pub struct Protocol<W> {
sender: W,
}
impl<W> Protocol<W>
where
W: AsyncWriteExt + Unpin,
{
pub fn new(sender: W) -> Self {
Self { sender }
}
#[instrument(skip(self, recver, key))]
pub async fn authenticate<R>(
&mut self,
recver: &mut R,
key: &ApiKey,
dataset: &str,
send_ts_out: bool,
heartbeat_interval_s: Option<i64>,
) -> crate::Result<String>
where
R: AsyncBufReadExt + Unpin,
{
let mut greeting = String::new();
recver.read_line(&mut greeting).await?;
greeting.pop();
debug!(greeting);
let mut response = String::new();
recver.read_line(&mut response).await?;
response.pop();
let challenge = Challenge::parse(&response).inspect_err(|_| {
error!(?response, "No CRAM challenge in response from gateway");
})?;
debug!(%challenge, "Received CRAM challenge");
let auth_req =
AuthRequest::new(key, dataset, send_ts_out, heartbeat_interval_s, &challenge);
debug!(?auth_req, "Sending CRAM reply");
self.sender.write_all(auth_req.as_bytes()).await.unwrap();
response.clear();
recver.read_line(&mut response).await?;
debug!(
auth_resp = &response[..response.len() - 1],
"Received auth response"
);
response.pop();
let auth_resp = AuthResponse::parse(&response)?;
Ok(auth_resp
.0
.get("session_id")
.map(|sid| (*sid).to_owned())
.unwrap_or_default())
}
pub async fn subscribe(&mut self, sub: &Subscription) -> crate::Result<()> {
let Subscription {
schema,
stype_in,
start,
use_snapshot,
..
} = ⊂
if *use_snapshot && start.is_some() {
return Err(Error::BadArgument {
param_name: "use_snapshot".to_string(),
desc: "cannot request snapshot with start time".to_string(),
});
}
let start_nanos = sub.start.as_ref().map(|start| start.unix_timestamp_nanos());
let symbol_chunks = sub.symbols.to_chunked_api_string();
let last_chunk_idx = symbol_chunks.len() - 1;
for (i, sym_str) in symbol_chunks.into_iter().enumerate() {
let sub_req = SubRequest::new(
*schema,
*stype_in,
start_nanos,
*use_snapshot,
sub.id,
&sym_str,
i == last_chunk_idx,
);
debug!(?sub_req, "Sending subscription request");
self.sender.write_all(sub_req.as_bytes()).await?;
}
Ok(())
}
pub async fn start_session(&mut self) -> crate::Result<()> {
Ok(self.sender.write_all(StartRequest.as_bytes()).await?)
}
pub async fn shutdown(&mut self) -> crate::Result<()> {
Ok(self.sender.shutdown().await?)
}
pub fn into_inner(self) -> W {
self.sender
}
}
#[derive(Debug)]
pub struct Challenge<'a>(&'a str);
impl<'a> Challenge<'a> {
pub fn parse(response: &'a str) -> crate::Result<Self> {
if response.starts_with("cram=") {
Ok(Self(response.split_once('=').unwrap().1))
} else {
Err(Error::internal(
"no CRAM challenge in response from gateway",
))
}
}
}
impl Display for Challenge<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// The CRAM authentication reply sent to the gateway, stored as a newline-terminated line.
pub struct AuthRequest(String);

impl AuthRequest {
    /// Builds the authentication reply for `challenge`.
    ///
    /// The `auth` field is the lowercase hex SHA-256 of `"<challenge>|<key>"`, followed by
    /// `-` and the key's bucket id. The line also carries:
    /// - `dataset`;
    /// - `encoding=dbn`;
    /// - `ts_out` (`0`/`1`);
    /// - a `client=Rust <crate version>` tag;
    /// - `heartbeat_interval_s`, only when it is `Some`.
    pub fn new(
        key: &ApiKey,
        dataset: &str,
        send_ts_out: bool,
        heartbeat_interval_s: Option<i64>,
        challenge: &Challenge,
    ) -> Self {
        let auth_hash = {
            let mut hasher = Sha256::new();
            hasher.update(format!("{challenge}|{}", key.0).as_bytes());
            hasher.finalize().encode_hex::<String>()
        };
        let mut line = format!(
            "auth={auth_hash}-{}|dataset={dataset}|encoding=dbn|ts_out={}|client=Rust {}",
            key.bucket_id(),
            u8::from(send_ts_out),
            env!("CARGO_PKG_VERSION"),
        );
        if let Some(interval) = heartbeat_interval_s {
            line.push_str(&format!("|heartbeat_interval_s={interval}"));
        }
        line.push('\n');
        Self(line)
    }

    /// Returns the request text, including the trailing newline.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Returns the request as raw bytes, including the trailing newline.
    pub fn as_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }
}

impl Debug for AuthRequest {
    /// Formats the request without its final character (the trailing newline).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (body, _newline) = self.0.split_at(self.0.len() - 1);
        f.write_str(body)
    }
}
pub struct AuthResponse<'a>(HashMap<&'a str, &'a str>);
impl<'a> AuthResponse<'a> {
pub fn parse(response: &'a str) -> crate::Result<Self> {
let auth_keys: HashMap<&'a str, &'a str> = response
.split('|')
.filter_map(|kvp| kvp.split_once('='))
.collect();
if auth_keys.get("success").map(|v| *v != "1").unwrap_or(true) {
return Err(Error::Auth(
auth_keys
.get("error")
.map(|msg| (*msg).to_owned())
.unwrap_or_else(|| response.to_owned()),
));
}
Ok(Self(auth_keys))
}
pub fn get_ref(&self) -> &HashMap<&'a str, &'a str> {
&self.0
}
}
/// A single subscription request line for the gateway, stored newline-terminated.
pub struct SubRequest(String);

impl SubRequest {
    /// Builds a subscription request line.
    ///
    /// # Arguments
    /// * `schema` - sent as `schema=`.
    /// * `stype_in` - sent as `stype_in=`.
    /// * `start_nanos` - sent as `start=` when `Some`, omitted otherwise.
    /// * `use_snapshot` - sent as `snapshot=0|1`.
    /// * `id` - sent as `id=` when `Some`, omitted otherwise.
    /// * `symbols` - sent verbatim as `symbols=`.
    /// * `is_last` - sent as `is_last=0|1`.
    pub fn new(
        schema: Schema,
        stype_in: SType,
        start_nanos: Option<i128>,
        use_snapshot: bool,
        id: Option<u32>,
        symbols: &str,
        is_last: bool,
    ) -> Self {
        let mut fields = vec![
            format!("schema={schema}"),
            format!("stype_in={stype_in}"),
            format!("symbols={symbols}"),
            format!("snapshot={}", u8::from(use_snapshot)),
            format!("is_last={}", u8::from(is_last)),
        ];
        // Optional fields follow the mandatory ones, with `start` before `id`.
        fields.extend(start_nanos.map(|start| format!("start={start}")));
        fields.extend(id.map(|id| format!("id={id}")));
        let mut line = fields.join("|");
        line.push('\n');
        Self(line)
    }

    /// Returns the request text, including the trailing newline.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Returns the request as raw bytes, including the trailing newline.
    pub fn as_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }
}

impl Debug for SubRequest {
    /// Formats the request without its final character (the trailing newline).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (body, _newline) = self.0.split_at(self.0.len() - 1);
        f.write_str(body)
    }
}
/// The `start_session` request line sent to the gateway.
pub struct StartRequest;

impl StartRequest {
    /// Raw text of the request, including its trailing newline.
    const TEXT: &'static str = "start_session\n";

    /// Returns the request text, including the trailing newline.
    pub fn as_str(&self) -> &str {
        Self::TEXT
    }

    /// Returns the request as raw bytes, including the trailing newline.
    pub fn as_bytes(&self) -> &[u8] {
        Self::TEXT.as_bytes()
    }
}