use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use parking_lot::Mutex;
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio::sync::{Notify, oneshot};
use tokio::task::JoinHandle;
use tracing::{debug, warn};
use super::codec::{read_response, write_request};
use super::types::{Request, Response, VectorizerValue};
/// Errors returned by [`RpcClient`] operations.
#[derive(Debug, thiserror::Error)]
pub enum RpcClientError {
/// Socket-level failure; converted from [`io::Error`] via `#[from]`.
#[error("network I/O error: {0}")]
Io(#[from] io::Error),
/// MessagePack encoding failure; converted from
/// `rmp_serde::encode::Error` via `#[from]`.
#[error("encode failed: {0}")]
Encode(#[from] rmp_serde::encode::Error),
/// Error message carried in a response's `Err` result. This file also uses
/// it for endpoint-URL failures (unparseable or REST URLs) and for a PING
/// reply that is not a string.
#[error("server error: {0}")]
Server(String),
/// The response channel was dropped or the reader task signalled that it
/// exited before a response for the request arrived.
#[error("connection closed before response (reader task ended)")]
ConnectionClosed,
/// Returned by [`RpcClient::call`] for any command other than `HELLO` or
/// `PING` while the client's authenticated flag is still `false`.
#[error("HELLO must succeed before any data-plane command can be issued")]
NotAuthenticated,
}
/// Result alias used throughout this module, with [`RpcClientError`] as the
/// error type.
pub type Result<T> = std::result::Result<T, RpcClientError>;
/// Arguments for the `HELLO` command.
///
/// Build one with [`HelloPayload::new`] and optionally attach a credential
/// with `with_token` or `with_api_key`; each of those clears the other
/// credential. Note that the derived `Default` leaves `version` at `0`,
/// whereas `new` sets it to `1`.
#[derive(Debug, Clone, Default)]
pub struct HelloPayload {
/// Credential sent under the `"token"` key when present.
pub token: Option<String>,
/// Credential sent under the `"api_key"` key when present.
pub api_key: Option<String>,
/// Client identifier sent under the `"client_name"` key when present.
pub client_name: Option<String>,
/// Always sent under the `"version"` key (presumably the protocol version
/// the client speaks — confirm against the server).
pub version: i64,
}
impl HelloPayload {
    /// Creates a payload naming the client as `client_name`, with `version`
    /// set to `1` and no credential attached.
    pub fn new(client_name: impl Into<String>) -> Self {
        Self {
            token: None,
            api_key: None,
            client_name: Some(client_name.into()),
            version: 1,
        }
    }

    /// Attaches `token` as the credential, discarding any API key set earlier.
    pub fn with_token(self, token: impl Into<String>) -> Self {
        Self {
            token: Some(token.into()),
            api_key: None,
            ..self
        }
    }

    /// Attaches `api_key` as the credential, discarding any token set earlier.
    pub fn with_api_key(self, api_key: impl Into<String>) -> Self {
        Self {
            api_key: Some(api_key.into()),
            token: None,
            ..self
        }
    }

    /// Converts the payload into the map value sent as the `HELLO` argument.
    ///
    /// `"version"` is always the first entry; `"token"`, `"api_key"` and
    /// `"client_name"` follow in that order, each only when it is set.
    fn into_value(self) -> VectorizerValue {
        // Builds one string-keyed, string-valued map entry.
        let entry = |key: &str, text: String| {
            (VectorizerValue::Str(key.into()), VectorizerValue::Str(text))
        };

        let mut pairs = vec![(
            VectorizerValue::Str("version".into()),
            VectorizerValue::Int(self.version),
        )];
        // Extending with an `Option` appends zero or one entry.
        pairs.extend(self.token.map(|token| entry("token", token)));
        pairs.extend(self.api_key.map(|api_key| entry("api_key", api_key)));
        pairs.extend(self.client_name.map(|name| entry("client_name", name)));
        VectorizerValue::Map(pairs)
    }
}
/// Decoded reply to a `HELLO` command.
///
/// Each field is read from the map key of the same name in the reply; a key
/// that is missing or has the wrong type yields the field's fallback value
/// (empty string, `0`, `false`, or an empty list).
#[derive(Debug, Clone)]
pub struct HelloResponse {
/// Value of the reply's `"server_version"` entry.
pub server_version: String,
/// Value of the reply's `"protocol_version"` entry.
pub protocol_version: i64,
/// Value of the reply's `"authenticated"` entry; when `true`,
/// `RpcClient::hello` marks the client as authenticated.
pub authenticated: bool,
/// Value of the reply's `"admin"` entry.
pub admin: bool,
/// String items of the reply's `"capabilities"` array; non-string items
/// are skipped.
pub capabilities: Vec<String>,
}
impl HelloResponse {
    /// Decodes a `HELLO` reply map.
    ///
    /// Decoding is lenient: a key that is missing or holds a value of the
    /// wrong type falls back to an empty string, `0`, `false`, or an empty
    /// list, so this never fails.
    fn parse(value: &VectorizerValue) -> Self {
        // Boolean lookup that treats a missing or non-bool entry as `false`.
        let flag = |key: &str| {
            value
                .map_get(key)
                .and_then(|entry| entry.as_bool())
                .unwrap_or(false)
        };

        let capabilities = match value
            .map_get("capabilities")
            .and_then(|entry| entry.as_array())
        {
            // Non-string items are silently skipped.
            Some(items) => items
                .iter()
                .filter_map(|item| item.as_str().map(str::to_owned))
                .collect(),
            None => Vec::new(),
        };

        Self {
            server_version: value
                .map_get("server_version")
                .and_then(|entry| entry.as_str())
                .map(str::to_owned)
                .unwrap_or_default(),
            protocol_version: value
                .map_get("protocol_version")
                .and_then(|entry| entry.as_int())
                .unwrap_or(0),
            authenticated: flag("authenticated"),
            admin: flag("admin"),
            capabilities,
        }
    }
}
/// In-flight requests keyed by request id.
///
/// `Some(map)` while the reader task is running. The reader task replaces it
/// with `None` when it exits, which drops every waiting sender and tells
/// later callers — under the same lock — that no response can arrive anymore.
type PendingTable = Option<HashMap<u32, oneshot::Sender<Response>>>;

/// RPC client over a single TCP connection.
///
/// Requests are written through a shared writer half. A background reader
/// task decodes responses and hands each one to the caller whose request id
/// matches, so several calls may be in flight at once.
pub struct RpcClient {
    /// Write half of the socket; the async mutex serialises whole requests.
    writer: Arc<tokio::sync::Mutex<tokio::net::tcp::OwnedWriteHalf>>,
    /// Response routing table shared with the reader task.
    pending: Arc<Mutex<PendingTable>>,
    /// Source of request ids; starts at 1 and wraps on overflow.
    next_id: AtomicU32,
    /// Notified by the reader task when it exits.
    reader_done: Arc<Notify>,
    /// Handle of the reader task; taken (and aborted) on close/drop.
    reader_task: Option<JoinHandle<()>>,
    /// Set once a `HELLO` reply reports `authenticated = true`.
    authenticated: Arc<Mutex<bool>>,
}

impl RpcClient {
    /// Connects to the server named by `url`.
    ///
    /// The URL is parsed with `parse_endpoint`; an RPC endpoint is dialled
    /// via [`RpcClient::connect`].
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::Server`] when the URL cannot be parsed or
    /// denotes a REST endpoint, and any error from [`RpcClient::connect`].
    pub async fn connect_url(url: &str) -> Result<Self> {
        use super::endpoint::{Endpoint, parse_endpoint};
        match parse_endpoint(url).map_err(|e| RpcClientError::Server(e.to_string()))? {
            Endpoint::Rpc { host, port } => Self::connect(format!("{host}:{port}")).await,
            Endpoint::Rest { url } => Err(RpcClientError::Server(format!(
                "RpcClient cannot dial REST URL '{url}'; \
                 use the HTTP client (`vectorizer_sdk::VectorizerClient`) instead, \
                 or pass a `vectorizer://` URL"
            ))),
        }
    }

    /// Opens a TCP connection to `addr` and spawns the background reader task.
    ///
    /// The returned client is not yet authenticated; see [`RpcClient::hello`].
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::Io`] when the TCP connection cannot be
    /// established.
    pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> Result<Self> {
        let stream = TcpStream::connect(addr).await?;
        let (read_half, write_half) = stream.into_split();
        let mut reader = BufReader::new(read_half);
        let pending: Arc<Mutex<PendingTable>> = Arc::new(Mutex::new(Some(HashMap::new())));
        let reader_done = Arc::new(Notify::new());
        let pending_for_reader = Arc::clone(&pending);
        let done_for_reader = Arc::clone(&reader_done);
        let reader_task = tokio::spawn(async move {
            loop {
                match read_response(&mut reader).await {
                    Ok(resp) => {
                        // The lock guard is a temporary of this statement, so
                        // it is released before the response is delivered.
                        let sender = pending_for_reader
                            .lock()
                            .as_mut()
                            .and_then(|table| table.remove(&resp.id));
                        match sender {
                            Some(tx) => {
                                // The caller may have given up; ignore that.
                                let _ = tx.send(resp);
                            }
                            None => {
                                warn!(
                                    id = resp.id,
                                    "RpcClient received response with no pending caller — dropping"
                                );
                            }
                        }
                    }
                    Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
                        debug!("RpcClient reader: clean EOF");
                        break;
                    }
                    Err(e) => {
                        warn!(error = %e, "RpcClient reader error — connection closed");
                        break;
                    }
                }
            }
            // Mark the connection closed and drop every waiting sender in one
            // step under the lock: a caller registering concurrently either
            // got in before (and is woken by its sender being dropped) or
            // observes `None` and fails fast instead of waiting forever.
            *pending_for_reader.lock() = None;
            done_for_reader.notify_waiters();
        });
        Ok(Self {
            writer: Arc::new(tokio::sync::Mutex::new(write_half)),
            pending,
            next_id: AtomicU32::new(1),
            reader_done,
            reader_task: Some(reader_task),
            authenticated: Arc::new(Mutex::new(false)),
        })
    }

    /// Sends `HELLO` with `payload` and returns the decoded reply.
    ///
    /// When the reply reports `authenticated = true` the client is marked
    /// authenticated, which unlocks data-plane commands in
    /// [`RpcClient::call`]. A reply with `authenticated = false` leaves the
    /// flag unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any transport, encode, or server error from the request.
    pub async fn hello(&self, payload: HelloPayload) -> Result<HelloResponse> {
        let value = payload.into_value();
        let result = self.raw_call("HELLO", vec![value]).await?;
        let parsed = HelloResponse::parse(&result);
        if parsed.authenticated {
            *self.authenticated.lock() = true;
        }
        Ok(parsed)
    }

    /// Sends `PING` and returns the server's string reply.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::Server`] when the reply is not a string, and
    /// propagates any transport, encode, or server error from the request.
    pub async fn ping(&self) -> Result<String> {
        let result = self.raw_call("PING", vec![]).await?;
        result
            .as_str()
            .map(str::to_owned)
            .ok_or_else(|| RpcClientError::Server("PING returned non-string payload".into()))
    }

    /// Issues `command` with `args` and returns the server's result value.
    ///
    /// # Errors
    ///
    /// Returns [`RpcClientError::NotAuthenticated`] without touching the
    /// network when the command is neither `HELLO` nor `PING` and the client
    /// is not yet authenticated; otherwise propagates any transport, encode,
    /// or server error from the request.
    pub async fn call(
        &self,
        command: impl Into<String>,
        args: Vec<VectorizerValue>,
    ) -> Result<VectorizerValue> {
        let cmd = command.into();
        let exempt = matches!(cmd.as_str(), "HELLO" | "PING");
        if !exempt && !*self.authenticated.lock() {
            return Err(RpcClientError::NotAuthenticated);
        }
        self.raw_call(cmd, args).await
    }

    /// Sends one request and waits for the matching response, without the
    /// authentication gate applied by [`RpcClient::call`].
    ///
    /// # Errors
    ///
    /// * [`RpcClientError::ConnectionClosed`] when the reader task has already
    ///   exited, or exits before the response arrives.
    /// * The error from `write_request` (converted via `From`) when the
    ///   request cannot be written.
    /// * [`RpcClientError::Server`] when the response carries an error.
    ///
    /// # Cancellation
    ///
    /// Dropping the returned future unregisters the request, so no pending
    /// entry is leaked. Dropping it while the request is being written may
    /// leave a partial request on the wire.
    async fn raw_call(
        &self,
        command: impl Into<String>,
        args: Vec<VectorizerValue>,
    ) -> Result<VectorizerValue> {
        /// Removes the request's entry from the pending table on scope exit,
        /// covering write failure, connection loss, and cancellation. After a
        /// delivered response the entry is already gone and this is a no-op.
        struct Unregister<'a> {
            pending: &'a Mutex<PendingTable>,
            id: u32,
        }

        impl Drop for Unregister<'_> {
            fn drop(&mut self) {
                if let Some(table) = self.pending.lock().as_mut() {
                    table.remove(&self.id);
                }
            }
        }

        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let (tx, rx) = oneshot::channel::<Response>();
        {
            // Scoped so the (non-async) lock guard is released before any await.
            let mut guard = self.pending.lock();
            match guard.as_mut() {
                Some(table) => {
                    table.insert(id, tx);
                }
                // The reader task is gone: nobody could ever deliver a response.
                None => return Err(RpcClientError::ConnectionClosed),
            }
        }
        let _unregister = Unregister {
            pending: &self.pending,
            id,
        };

        let req = Request {
            id,
            command: command.into(),
            args,
        };
        {
            let mut writer = self.writer.lock().await;
            write_request(&mut *writer, &req).await?;
        }

        let resp = tokio::select! {
            // Poll the response first: if it was delivered just before the
            // reader exited, both branches are ready and the response must win.
            biased;
            recv = rx => match recv {
                Ok(resp) => resp,
                Err(_) => return Err(RpcClientError::ConnectionClosed),
            },
            _ = self.reader_done.notified() => {
                return Err(RpcClientError::ConnectionClosed);
            }
        };
        resp.result.map_err(RpcClientError::Server)
    }

    /// Returns `true` once a `HELLO` reply has reported `authenticated = true`.
    pub fn is_authenticated(&self) -> bool {
        *self.authenticated.lock()
    }

    /// Consumes the client and aborts the background reader task.
    pub fn close(mut self) {
        if let Some(handle) = self.reader_task.take() {
            handle.abort();
        }
    }
}
impl Drop for RpcClient {
    /// Aborts the background reader task if its handle has not already been
    /// taken, so the task does not outlive the client.
    fn drop(&mut self) {
        if let Some(reader) = self.reader_task.take() {
            reader.abort();
        }
    }
}