use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use futures::StreamExt;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tonic::transport::Channel;
use tracing::{debug, error, info};
use crate::api::remote::{
ConnectionSetupRequest, RequestTrait, ResponseTrait, ServerCheckRequest, ServerCheckResponse,
};
use crate::api::{
bi_request_stream_client::BiRequestStreamClient, request_client::RequestClient, Payload,
};
use crate::common::{DEFAULT_HEARTBEAT_INTERVAL_MS, DEFAULT_TIMEOUT_MS, LABEL_APP_NAME, LABEL_SOURCE, LABEL_SOURCE_SDK};
use crate::error::{BatataError, Result};
use crate::remote::ServerAddress;
/// Lifecycle state of a [`GrpcConnection`].
///
/// `Disconnected` is the [`Default`], matching the state a freshly created connection starts in.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// No live connection.
    #[default]
    Disconnected,
    /// A `connect()` attempt is in progress.
    Connecting,
    /// The server check and stream setup completed.
    Connected,
    /// Not assigned anywhere in this file; presumably reserved for automatic reconnection — TODO confirm.
    Reconnecting,
}
/// Callback invoked for a server-pushed [`Payload`] whose metadata type matches the key it was
/// registered under.
///
/// Returning `Some(reply)` sends `reply` back to the server over the bi-directional stream;
/// returning `None` sends nothing.
pub type ServerPushHandler = Arc<dyn Fn(Payload) -> Option<Payload> + Sync + Send>;
pub struct GrpcConnection {
server_address: ServerAddress,
connection_id: Arc<RwLock<String>>,
state: Arc<RwLock<ConnectionState>>,
client_ip: String,
namespace: String,
app_name: String,
labels: HashMap<String, String>,
sender: Arc<RwLock<Option<mpsc::Sender<Payload>>>>,
push_handlers: Arc<RwLock<HashMap<String, ServerPushHandler>>>,
timeout_ms: u64,
#[allow(dead_code)]
heartbeat_interval_ms: u64,
}
impl GrpcConnection {
pub fn new(server_address: ServerAddress) -> Self {
Self {
server_address,
connection_id: Arc::new(RwLock::new(String::new())),
state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
client_ip: get_local_ip(),
namespace: String::new(),
app_name: String::new(),
labels: HashMap::new(),
sender: Arc::new(RwLock::new(None)),
push_handlers: Arc::new(RwLock::new(HashMap::new())),
timeout_ms: DEFAULT_TIMEOUT_MS,
heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
}
}
pub fn with_namespace(mut self, namespace: &str) -> Self {
self.namespace = namespace.to_string();
self
}
pub fn with_app_name(mut self, app_name: &str) -> Self {
self.app_name = app_name.to_string();
self
}
pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
self.labels = labels;
self
}
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
self.timeout_ms = timeout_ms;
self
}
pub fn connection_id(&self) -> String {
self.connection_id.read().clone()
}
pub fn state(&self) -> ConnectionState {
*self.state.read()
}
pub fn is_connected(&self) -> bool {
*self.state.read() == ConnectionState::Connected
}
pub fn register_push_handler(&self, message_type: &str, handler: ServerPushHandler) {
self.push_handlers
.write()
.insert(message_type.to_string(), handler);
}
pub async fn connect(&self) -> Result<()> {
{
let mut state = self.state.write();
if *state == ConnectionState::Connected {
return Ok(());
}
*state = ConnectionState::Connecting;
}
let endpoint = self.server_address.grpc_endpoint();
info!("Connecting to server: {}", endpoint);
let channel = Channel::from_shared(endpoint.clone())
.map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
.connect_timeout(Duration::from_millis(self.timeout_ms))
.connect()
.await?;
let connection_id = self.server_check(&channel).await?;
*self.connection_id.write() = connection_id.clone();
let mut bi_client = BiRequestStreamClient::new(channel.clone());
let (tx, rx) = mpsc::channel::<Payload>(100);
*self.sender.write() = Some(tx.clone());
let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);
let response = bi_client.request_bi_stream(outbound).await?;
let mut inbound = response.into_inner();
self.send_connection_setup(&tx).await?;
let push_handlers = self.push_handlers.clone();
let state = self.state.clone();
let tx_clone = tx.clone();
tokio::spawn(async move {
while let Some(result) = inbound.next().await {
match result {
Ok(payload) => {
if let Some(metadata) = &payload.metadata {
let msg_type = &metadata.r#type;
debug!("Received message type: {}", msg_type);
let handler_opt = {
let handlers = push_handlers.read();
handlers.get(msg_type).cloned()
};
if let Some(handler) = handler_opt
&& let Some(response) = handler(payload)
&& let Err(e) = tx_clone.send(response).await
{
error!("Failed to send response: {}", e);
}
}
}
Err(e) => {
error!("Stream error: {}", e);
*state.write() = ConnectionState::Disconnected;
break;
}
}
}
});
*self.state.write() = ConnectionState::Connected;
info!(
"Connected to server, connection_id: {}",
self.connection_id()
);
Ok(())
}
pub async fn disconnect(&self) {
*self.state.write() = ConnectionState::Disconnected;
*self.sender.write() = None;
info!("Disconnected from server");
}
pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
where
Req: RequestTrait + serde::Serialize,
Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
{
if !self.is_connected() {
return Err(BatataError::ClientNotStarted);
}
let payload = request.to_payload(&self.client_ip);
let channel = Channel::from_shared(self.server_address.grpc_endpoint())
.map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
.connect_timeout(Duration::from_millis(self.timeout_ms))
.connect()
.await?;
let mut client = RequestClient::new(channel);
let response = timeout(
Duration::from_millis(self.timeout_ms),
client.request(payload),
)
.await
.map_err(|_| BatataError::Timeout {
timeout_ms: self.timeout_ms,
})??;
let payload = response.into_inner();
let resp: Resp = payload
.body
.as_ref()
.and_then(|body| serde_json::from_slice(&body.value).ok())
.unwrap_or_default();
if !resp.is_success() {
return Err(BatataError::server_error(resp.error_code(), resp.message()));
}
Ok(resp)
}
pub async fn send(&self, payload: Payload) -> Result<()> {
let sender = self.sender.read().clone();
let sender = sender.ok_or(BatataError::ClientNotStarted)?;
sender
.send(payload)
.await
.map_err(|e| BatataError::connection_error(format!("Failed to send: {}", e)))
}
async fn server_check(&self, channel: &Channel) -> Result<String> {
let mut client = RequestClient::new(channel.clone());
let request = ServerCheckRequest::new();
let payload = request.to_payload(&self.client_ip);
let response = timeout(
Duration::from_millis(self.timeout_ms),
client.request(payload),
)
.await
.map_err(|_| BatataError::Timeout {
timeout_ms: self.timeout_ms,
})??;
let payload = response.into_inner();
let resp: ServerCheckResponse = payload
.body
.as_ref()
.and_then(|body| serde_json::from_slice(&body.value).ok())
.unwrap_or_default();
if !resp.is_success() {
return Err(BatataError::server_error(resp.error_code(), resp.message()));
}
Ok(resp.connection_id)
}
async fn send_connection_setup(&self, sender: &mpsc::Sender<Payload>) -> Result<()> {
let mut labels = self.labels.clone();
labels.insert(LABEL_SOURCE.to_string(), LABEL_SOURCE_SDK.to_string());
if !self.app_name.is_empty() {
labels.insert(LABEL_APP_NAME.to_string(), self.app_name.clone());
}
let request = ConnectionSetupRequest::new()
.with_labels(labels)
.with_tenant(self.namespace.clone());
let payload = request.to_payload(&self.client_ip);
sender
.send(payload)
.await
.map_err(|e| BatataError::connection_error(format!("Failed to send setup: {}", e)))
}
}
fn get_local_ip() -> String {
if let Ok(addrs) = if_addrs::get_if_addrs() {
for iface in addrs {
if !iface.is_loopback()
&& let std::net::IpAddr::V4(ipv4) = iface.ip()
{
return ipv4.to_string();
}
}
}
"127.0.0.1".to_string()
}