use super::kafka_node::{ConnectionFactory, KafkaAddress};
use crate::{
connection::SinkConnection,
tls::{TlsConnector, TlsConnectorConfig},
};
use anyhow::{Context, Result};
use futures::stream::FuturesUnordered;
use kafka_protocol::protocol::StrBytes;
use metrics::{Histogram, histogram};
use rand::SeedableRng;
use rand::rngs::SmallRng;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Notify;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::StreamExt;
pub(crate) mod connection;
mod create_token;
mod recreate_token_queue;
/// A request sent over the channel to the background token task, asking for
/// the delegation token of a single user.
pub struct TokenRequest {
/// Name of the user whose delegation token is requested.
username: String,
/// One-shot channel the token task uses to send the token back.
/// The task ignores a failed send, so the receiving half may already be dropped.
response_tx: oneshot::Sender<DelegationToken>,
}
/// Handle used to send token requests to the background task spawned by [`TokenTask::new`].
///
/// Cloning the handle clones the sending half of the request channel.
#[derive(Clone)]
pub struct TokenTask {
/// Sending half of the request channel read by the background task.
tx: mpsc::Sender<TokenRequest>,
}
impl TokenTask {
    /// Spawns the background token task and returns a handle for talking to it.
    ///
    /// The spawned task keeps running `task` until it returns `Ok`. Whenever it
    /// returns an error, the error is logged and `task` is started again, which
    /// begins with empty state.
    ///
    /// * `mtls_connection_factory` - factory handed to `task` for opening connections.
    /// * `mtls_port_contact_points` - addresses handed to `task` for building its node list.
    /// * `delegation_token_lifetime` - lifetime handed to `task` for token creation.
    /// * `chain_name` - value of the `chain` label on the token creation time histogram.
    ///
    /// Uses `tokio::spawn`, so it has to be called from within a Tokio runtime.
    pub fn new(
        mtls_connection_factory: ConnectionFactory,
        mtls_port_contact_points: Vec<KafkaAddress>,
        delegation_token_lifetime: Duration,
        chain_name: &str,
    ) -> TokenTask {
        let creation_time_histogram = histogram!(
            "shotover_kafka_delegation_token_creation_seconds",
            "transform" => "KafkaSinkCluster",
            "chain" => chain_name.to_string()
        );
        let (tx, mut requests) = mpsc::channel::<TokenRequest>(1000);

        tokio::spawn(async move {
            // `Ok` means the task finished for good; any `Err` means it has to be started again.
            while let Err(err) = task(
                &mut requests,
                &mtls_connection_factory,
                &mtls_port_contact_points,
                delegation_token_lifetime,
                &creation_time_histogram,
            )
            .await
            {
                tracing::error!("Token task restarting due to failure, error was {err:?}");
            }
        });

        TokenTask { tx }
    }

    /// Queues a token request for `username` and hands back the receiver for the reply.
    ///
    /// Fails when the request could not be sent to the token task.
    async fn request_token(&self, username: String) -> Result<oneshot::Receiver<DelegationToken>> {
        let (response_tx, response_rx) = oneshot::channel();
        let request = TokenRequest {
            username,
            response_tx,
        };
        self.tx
            .send(request)
            .await
            .context("Failed to request delegation token from token task")?;
        Ok(response_rx)
    }

    /// Asks the token task for the token of `username` without waiting for the reply.
    ///
    /// Only the sending of the request is awaited; the reply channel is dropped.
    ///
    /// # Errors
    /// Returns an error when the request could not be sent to the token task.
    pub async fn prefetch_token_for_user(&self, username: String) -> Result<()> {
        // The receiver is intentionally discarded, nobody is waiting for this reply.
        let _response_rx = self.request_token(username).await?;
        Ok(())
    }

    /// Asks the token task for the token of `username` and waits for the reply.
    ///
    /// # Errors
    /// Returns an error when the request could not be sent to the token task,
    /// or when the reply channel was closed before a token was sent on it.
    pub async fn get_token_for_user(&self, username: String) -> Result<DelegationToken> {
        let response_rx = self.request_token(username).await?;
        response_rx
            .await
            .context("Token task encountered an error before it could respond to request for token")
    }
}
async fn task(
rx: &mut mpsc::Receiver<TokenRequest>,
mtls_connection_factory: &ConnectionFactory,
mtls_addresses: &[KafkaAddress],
delegation_token_lifetime: Duration,
token_creation_time_metric: &Histogram,
) -> Result<()> {
let mut rng = SmallRng::from_rng(&mut rand::rng());
let mut username_to_token = HashMap::new();
let mut recreate_queue =
recreate_token_queue::RecreateTokenQueue::new(delegation_token_lifetime);
let mut nodes = vec![];
loop {
tokio::select! {
biased;
username = recreate_queue.next() => {
let instant = Instant::now();
let token = create_token::create_token_with_timeout(
&mut nodes,
&mut rng,
mtls_connection_factory,
&username,
delegation_token_lifetime
).await
.with_context(|| format!("Failed to recreate delegation token for {username:?}"))?;
username_to_token.insert(username.clone(), token);
recreate_queue.push(username.clone());
let passed = instant.elapsed();
tracing::info!("Delegation token for {username:?} recreated in {passed:?}");
token_creation_time_metric.record(passed);
}
result = rx.recv() => {
if let Some(request) = result {
let instant = Instant::now();
if nodes.is_empty() {
let mut futures = FuturesUnordered::new();
for address in mtls_addresses {
futures.push(async move {
let connection = match mtls_connection_factory
.create_connection_unauthed(address)
.await
{
Ok(connection) => Some(connection),
Err(err) => {
tracing::error!("Token Task: Failed to create connection for {address:?} during nodes list init {err}");
None
}
};
Node {
connection,
address: address.clone(),
}
});
}
while let Some(node) = futures.next().await {
nodes.push(node);
}
}
let token = if let Some(token) = username_to_token.get(&request.username).cloned() {
token
} else {
let token = create_token::create_token_with_timeout(
&mut nodes,
&mut rng,
mtls_connection_factory,
&request.username,
delegation_token_lifetime,
).await
.with_context(|| format!("Failed to create delegation token for {:?}", request.username))?;
username_to_token.insert(request.username.clone(), token.clone());
recreate_queue.push(request.username.clone());
let passed = instant.elapsed();
tracing::info!("Delegation token for {:?} created in {passed:?}", request.username);
token_creation_time_metric.record(passed);
token
};
request.response_tx.send(token).ok();
}
else {
return Ok(())
}
}
}
}
}
/// Configuration from which an [`AuthorizeScramOverMtlsBuilder`] is created.
///
/// Unknown fields are rejected when deserializing.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct AuthorizeScramOverMtlsConfig {
/// Contact point addresses, each parsed with `KafkaAddress::from_str`.
pub mtls_port_contact_points: Vec<String>,
/// TLS settings used to construct the `TlsConnector` of the connection factory.
pub tls: TlsConnectorConfig,
/// Delegation token lifetime in whole seconds.
pub delegation_token_lifetime_seconds: u64,
}
impl AuthorizeScramOverMtlsConfig {
    /// Builds an [`AuthorizeScramOverMtlsBuilder`] from this configuration.
    ///
    /// This constructs the TLS connection factory, parses the contact points
    /// and starts the token task via `TokenTask::new`.
    ///
    /// * `connect_timeout` / `read_timeout` - passed through to the connection factory.
    /// * `chain_name` - passed through to the token task.
    ///
    /// # Errors
    /// Returns an error when the TLS connector cannot be created or when any
    /// contact point fails to parse. The token task is not started in either case.
    pub fn get_builder(
        &self,
        connect_timeout: Duration,
        read_timeout: Option<Duration>,
        chain_name: &str,
    ) -> Result<AuthorizeScramOverMtlsBuilder> {
        let tls_connector = TlsConnector::new(&self.tls)?;
        let mtls_connection_factory = ConnectionFactory::new(
            Some(tls_connector),
            connect_timeout,
            read_timeout,
            Arc::new(Notify::new()),
        );

        // Stops at the first address that fails to parse.
        let contact_points = self
            .mtls_port_contact_points
            .iter()
            .map(|address| KafkaAddress::from_str(address))
            .collect::<Result<Vec<_>>>()?;

        let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds);
        let token_task = TokenTask::new(
            mtls_connection_factory,
            contact_points,
            delegation_token_lifetime,
            chain_name,
        );

        Ok(AuthorizeScramOverMtlsBuilder {
            token_task,
            delegation_token_lifetime,
        })
    }
}
/// Holds what is needed to create [`AuthorizeScramOverMtls`] instances via `build`.
pub struct AuthorizeScramOverMtlsBuilder {
/// Handle to the token task; cloned into every instance created by `build`.
pub token_task: TokenTask,
/// Delegation token lifetime; copied into every instance created by `build`.
pub delegation_token_lifetime: Duration,
}
impl AuthorizeScramOverMtlsBuilder {
    /// Creates a fresh [`AuthorizeScramOverMtls`].
    ///
    /// The new instance starts in `OriginalScramState::WaitingOnServerFirst`
    /// with an empty username, and gets a clone of the token task handle
    /// plus the configured delegation token lifetime.
    pub fn build(&self) -> AuthorizeScramOverMtls {
        let Self {
            token_task,
            delegation_token_lifetime,
        } = self;

        AuthorizeScramOverMtls {
            original_scram_state: OriginalScramState::WaitingOnServerFirst,
            token_task: token_task.clone(),
            username: String::new(),
            delegation_token_lifetime: *delegation_token_lifetime,
        }
    }
}
/// State created by `AuthorizeScramOverMtlsBuilder::build` for requesting
/// delegation tokens on behalf of one user.
pub struct AuthorizeScramOverMtls {
/// Progress of the original SCRAM exchange.
/// `get_token_for_user` panics unless this is `AuthSuccess`.
pub original_scram_state: OriginalScramState,
/// Handle to the background token task.
token_task: TokenTask,
/// Username set by `set_username`; empty until then.
username: String,
/// Delegation token lifetime copied from the builder.
pub delegation_token_lifetime: Duration,
}
impl AuthorizeScramOverMtls {
    /// Records `username` and asks the token task to start fetching its token.
    ///
    /// The username is only stored once the prefetch request has been sent.
    ///
    /// # Errors
    /// Returns an error when the prefetch request could not be sent,
    /// in which case the stored username is left unchanged.
    pub async fn set_username(&mut self, username: String) -> Result<()> {
        let prefetch = self.token_task.prefetch_token_for_user(username.clone());
        prefetch.await?;
        self.username = username;
        Ok(())
    }

    /// Returns the delegation token for the stored username.
    ///
    /// # Panics
    /// Panics when `original_scram_state` is not `AuthSuccess`, so that a token
    /// is never handed out before authentication has succeeded.
    ///
    /// # Errors
    /// Returns an error when the token task fails to provide a token.
    pub async fn get_token_for_user(&self) -> Result<DelegationToken> {
        assert!(
            matches!(self.original_scram_state, OriginalScramState::AuthSuccess),
            "Cannot hand out tokens to a connection that has not authenticated yet."
        );
        let username = self.username.clone();
        self.token_task.get_token_for_user(username).await
    }
}
/// Progress of the original SCRAM exchange tracked in [`AuthorizeScramOverMtls`].
///
/// NOTE(review): the transitions between these states are not performed in this
/// file; the per-variant meanings below are inferred from the names — confirm
/// against the code that drives the state.
pub enum OriginalScramState {
/// Initial state set by `AuthorizeScramOverMtlsBuilder::build`.
WaitingOnServerFirst,
/// Presumably: waiting for the server's final SCRAM message.
WaitingOnServerFinal,
/// Presumably: the SCRAM exchange failed.
AuthFailed,
/// The only state in which `AuthorizeScramOverMtls::get_token_for_user` hands out tokens.
AuthSuccess,
}
/// An entry in the token task's node list: an address plus its connection, if any.
struct Node {
/// Address this node was created from.
address: KafkaAddress,
/// Connection to the node; `None` when the connection attempt made while
/// building the node list failed.
connection: Option<SinkConnection>,
}
/// A delegation token as cached and handed out by the token task.
#[derive(Clone)]
pub struct DelegationToken {
/// Identifier of the token.
pub token_id: String,
/// HMAC belonging to the token.
pub hmac: StrBytes,
}