use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use dashmap::DashMap;
use crabka_client_core::{ClientError, Connection, ConnectionOptions};
use crabka_metadata::voters::VoterSet;
use crate::error::RaftError;
use crate::kraft::transport::{PeerSender, api_key};
use crate::kraft::types::NodeId;
/// Opens outbound connections to other Raft nodes.
///
/// `RealPeerSender` calls `dial` whenever it has no cached connection for a peer. It passes
/// the peer's id, the `host:port` string it resolved from the voter set, and
/// `ConnectionOptions` carrying its client id. `PlaintextDialer` in this module is an
/// implementation that connects directly to the address parsed as a `SocketAddr`.
#[async_trait]
pub trait OutboundDialer: Send + Sync {
/// Connects to `addr`, the advertised address of node `target`.
///
/// # Errors
///
/// Returns a `ClientError` if the connection cannot be established.
async fn dial(
&self,
target: NodeId,
addr: &str,
options: ConnectionOptions,
) -> Result<Connection, ClientError>;
}
/// `OutboundDialer` that hands `addr`, parsed as a `SocketAddr`, straight to
/// `Connection::connect`.
///
/// Only numeric socket-address literals (`ip:port`, `[ipv6]:port`) are accepted. Host
/// names are not resolved and fail with `ErrorKind::InvalidInput`. The `target` node id
/// is ignored.
pub struct PlaintextDialer;

#[async_trait]
impl OutboundDialer for PlaintextDialer {
    async fn dial(
        &self,
        _target: NodeId,
        addr: &str,
        options: ConnectionOptions,
    ) -> Result<Connection, ClientError> {
        let sock = match addr.parse::<SocketAddr>() {
            Ok(parsed) => parsed,
            Err(e) => {
                // Surface a malformed peer address as an I/O error so callers see the
                // same `ClientError` variant they would get for a failed connect.
                let io_err = std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("invalid raft peer address {addr:?}: {e}"),
                );
                return Err(ClientError::Io(io_err));
            }
        };
        Connection::connect(sock, options).await
    }
}
/// Resolves the `host:port` address used to reach voter `id`.
///
/// Prefers the voter's endpoint named `CONTROLLER` and otherwise falls back to its first
/// listed endpoint. IPv6 literal hosts are wrapped in brackets (`[::1]:9093`) so the
/// result can be parsed as a `SocketAddr`.
///
/// Returns `None` if `id` is not in `voters` or the voter has no endpoints.
fn controller_addr(voters: &VoterSet, id: NodeId) -> Option<String> {
    let voter = voters.get(id)?;
    let endpoint = voter
        .endpoints
        .iter()
        .find(|e| e.name == "CONTROLLER")
        .or_else(|| voter.endpoints.first())?;
    let host = &endpoint.host;
    // A bare IPv6 literal such as `::1` would produce `::1:9093`. That string is
    // ambiguous, and `SocketAddr`'s parser rejects it; the bracketed form is required.
    let addr = if host.contains(':') && !host.starts_with('[') {
        format!("[{}]:{}", host, endpoint.port)
    } else {
        format!("{}:{}", host, endpoint.port)
    };
    Some(addr)
}
/// Returns the request version this sender uses for the given KRaft API `key`.
///
/// `VOTE` uses v2; `BEGIN_QUORUM_EPOCH`, `END_QUORUM_EPOCH` and `FETCH_SNAPSHOT` use v1;
/// `FETCH` uses v17. Any other key falls back to version 0.
fn api_version_for(key: i16) -> i16 {
    if key == api_key::VOTE {
        2
    } else if matches!(
        key,
        api_key::BEGIN_QUORUM_EPOCH | api_key::END_QUORUM_EPOCH | api_key::FETCH_SNAPSHOT
    ) {
        1
    } else if key == api_key::FETCH {
        17
    } else {
        0
    }
}
/// `PeerSender` that sends raw requests to Raft peers over cached, shared connections.
pub(crate) struct RealPeerSender {
// Open connections keyed by peer id; entries are created on first use and evicted
// after a request on them fails.
connections: DashMap<NodeId, Arc<Connection>>,
// Voter set used to look up each peer's address before dialing.
voters: VoterSet,
// Client id placed in the `ConnectionOptions` of every dialed connection.
client_id: String,
// Strategy used to open new outbound connections.
dialer: Arc<dyn OutboundDialer>,
}
impl RealPeerSender {
    /// Creates a sender with an empty connection cache.
    ///
    /// `voters` resolves peer addresses, `client_id` is set on every outbound connection,
    /// and `dialer` opens the connections.
    pub(crate) fn new(
        voters: VoterSet,
        client_id: String,
        dialer: Arc<dyn OutboundDialer>,
    ) -> Self {
        Self {
            connections: DashMap::new(),
            voters,
            client_id,
            dialer,
        }
    }

    /// Returns the cached connection to `peer`, dialing and caching one if none exists.
    ///
    /// # Errors
    ///
    /// * `RaftError::NotLeader { current_leader: None }` if `peer` is not in the voter set
    ///   or has no endpoints.
    /// * The dialer's `ClientError`, converted into `RaftError` by `?`, if dialing fails.
    async fn connect(&self, peer: NodeId) -> Result<Arc<Connection>, RaftError> {
        if let Some(cached) = self.connections.get(&peer) {
            return Ok(Arc::clone(cached.value()));
        }
        let addr = controller_addr(&self.voters, peer).ok_or(RaftError::NotLeader {
            current_leader: None,
        })?;
        let opts = ConnectionOptions {
            client_id: self.client_id.clone(),
            ..ConnectionOptions::default()
        };
        let dialed = Arc::new(self.dialer.dial(peer, &addr, opts).await?);
        // Several tasks can miss the lookup above and dial `peer` concurrently. Keep
        // whichever connection reached the cache first so all callers share it. A losing
        // duplicate is dropped here instead of overwriting a connection that other tasks
        // are already using. The shard guard is released at the end of this statement,
        // so it is never held across an `.await`.
        let conn = Arc::clone(self.connections.entry(peer).or_insert(dialed).value());
        Ok(conn)
    }
}
#[async_trait]
impl PeerSender for RealPeerSender {
async fn send(&self, peer: NodeId, key: i16, body: Bytes) -> Result<Bytes, RaftError> {
let conn = self.connect(peer).await?;
let version = api_version_for(key);
match conn.raw_request(key, version, body).await {
Ok(resp) => Ok(resp),
Err(e) => {
self.connections.remove(&peer);
Err(RaftError::Network(e))
}
}
}
}