use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use axum::Router;
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::{Html, IntoResponse};
use axum::routing::get;
use kafkit_client::{
AdminConfig, AutoOffsetReset, ConsumerRecord, KafkaAdmin, KafkaClient, KafkaMessage,
KafkaProducer, NewTopic, RecordHeader,
};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::broadcast;
/// Kafka bootstrap servers used when `KAFKIT_BOOTSTRAP_SERVERS` is unset.
const DEFAULT_BOOTSTRAP: &str = "localhost:9092";
/// Topic produced to and consumed from when `KAFKIT_TOPIC` is unset.
const DEFAULT_TOPIC: &str = "kafkit.orders";
/// Consumer group of the background broadcaster when `KAFKIT_GROUP_ID` is unset.
const DEFAULT_GROUP: &str = "kafkit-websocket-broadcaster";
/// HTTP/WebSocket listen address when `KAFKIT_WS_BIND` is unset.
const DEFAULT_BIND: &str = "127.0.0.1:8081";
/// Capacity of the tokio broadcast channel; a subscriber that falls further
/// behind than this receives `RecvError::Lagged` instead of the skipped records.
const BROADCAST_CAPACITY: usize = 1024;
/// Entry point.
///
/// Reads configuration from `KAFKIT_BOOTSTRAP_SERVERS`, `KAFKIT_TOPIC`,
/// `KAFKIT_GROUP_ID` and `KAFKIT_WS_BIND` (falling back to the `DEFAULT_*`
/// constants), ensures the topic exists, connects the shared producer, spawns
/// the background consumer that feeds the broadcast channel, then serves the UI,
/// `/health` and `/ws` on the bind address.
///
/// If the background consumer fails, its error is only printed: the web server
/// keeps running, and because `AppState` still holds a sender, subscribers do
/// not observe a closed channel; they just stop receiving records.
#[tokio::main]
async fn main() -> Result<()> {
    // Any lookup failure (unset or non-UTF-8) falls back to the default.
    let env_or =
        |name: &str, fallback: &str| env::var(name).unwrap_or_else(|_| fallback.to_owned());

    let bootstrap = env_or("KAFKIT_BOOTSTRAP_SERVERS", DEFAULT_BOOTSTRAP);
    let topic = env_or("KAFKIT_TOPIC", DEFAULT_TOPIC);
    let group_id = env_or("KAFKIT_GROUP_ID", DEFAULT_GROUP);
    let bind: SocketAddr = env_or("KAFKIT_WS_BIND", DEFAULT_BIND)
        .parse()
        .context("KAFKIT_WS_BIND must be host:port, for example 127.0.0.1:8081")?;

    ensure_topic(&bootstrap, &topic).await?;

    let producer = Arc::new(
        KafkaClient::new(bootstrap.clone())
            .topic(topic.clone())
            .producer()
            .with_client_id("kafkit-websocket-producer")
            .connect()
            .await
            .context("failed to connect websocket producer")?,
    );

    // The initial receiver is discarded; each WebSocket client subscribes its own.
    let (broadcast_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
    let app_state = Arc::new(AppState {
        broadcast_tx: broadcast_tx.clone(),
        producer,
    });

    // Futures are lazy: nothing runs until the spawned task polls this.
    let consumer_task = consume_and_broadcast(
        bootstrap.clone(),
        topic.clone(),
        group_id.clone(),
        broadcast_tx,
    );
    tokio::spawn(async move {
        if let Err(error) = consumer_task.await {
            eprintln!("background consumer stopped: {error:#}");
        }
    });

    let app = Router::new()
        .route("/", get(index))
        .route("/health", get(health))
        .route("/ws", get(ws_handler))
        .with_state(app_state);

    println!("Kafka bootstrap: {bootstrap}");
    println!("Kafka topic: {topic}");
    println!("Kafka group: {group_id}");
    println!("WebSocket UI: http://{bind}");
    println!("WebSocket endpoint: ws://{bind}/ws");

    let listener = TcpListener::bind(bind)
        .await
        .with_context(|| format!("failed to bind {bind}"))?;
    axum::serve(listener, app)
        .await
        .context("web server failed")
}
/// State shared with every handler through axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Sending half of the record fan-out channel; handlers call `subscribe`
    /// on it to obtain a per-connection receiver.
    broadcast_tx: broadcast::Sender<String>,
    /// Single producer connection reused by every WebSocket client.
    producer: Arc<KafkaProducer>,
}
/// `GET /health`: fixed plain-text liveness response.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
/// `GET /`: serves the single-page demo UI.
///
/// The page opens a WebSocket to `/ws` on the same host. It uses `wss:` when the
/// page itself was loaded over HTTPS, because browsers block plain `ws:` from
/// secure pages. Every JSON frame the server pushes is pretty-printed into a
/// scrollable log. On submit, the form fields are sent as a produce command
/// (`key`, `value`, `headers`). Submits while the socket is not open are
/// reported in the log instead of being thrown or silently dropped by
/// `WebSocket.send`.
async fn index() -> Html<&'static str> {
    Html(
        r#"<!doctype html>
<html>
<head>
<meta charset="utf-8">
<title>kafkit-client websocket broadcast</title>
<style>
body { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; margin: 2rem; max-width: 72rem; }
form { display: grid; gap: .75rem; grid-template-columns: 12rem 1fr auto; align-items: end; margin-bottom: 1rem; }
label { display: grid; gap: .25rem; }
input { font: inherit; padding: .5rem; }
button { font: inherit; padding: .55rem .9rem; cursor: pointer; }
pre { background: #111827; color: #e5e7eb; padding: 1rem; border-radius: 6px; min-height: 20rem; max-height: 70vh; overflow-y: auto; white-space: pre-wrap; }
</style>
</head>
<body>
<h1>Kafka records</h1>
<p>Send a message over the websocket. The server produces it to Kafka; the background consumer broadcasts it back to every subscriber.</p>
<form id="producer">
<label>Key <input id="key" value="browser-1"></label>
<label>Value <input id="value" value='{"source":"browser","status":"created"}'></label>
<button type="submit">Produce</button>
</form>
<pre id="log"></pre>
<script>
const log = document.getElementById("log");
const form = document.getElementById("producer");
const key = document.getElementById("key");
const value = document.getElementById("value");
// Browsers refuse plain ws:// connections from pages served over https://.
const scheme = location.protocol === "https:" ? "wss" : "ws";
const ws = new WebSocket(`${scheme}://${location.host}/ws`);
const append = (text) => {
  log.textContent += text;
  log.scrollTop = log.scrollHeight;
};
ws.onopen = () => append("connected\n");
ws.onmessage = (event) => {
  append(JSON.stringify(JSON.parse(event.data), null, 2) + "\n\n");
};
ws.onclose = () => append("closed\n");
ws.onerror = () => append("websocket error\n");
form.onsubmit = (event) => {
  event.preventDefault();
  // send() throws while CONNECTING and silently discards data once CLOSING/CLOSED.
  if (ws.readyState !== WebSocket.OPEN) {
    append("not connected; message not sent\n");
    return;
  }
  ws.send(JSON.stringify({
    key: key.value || null,
    value: value.value,
    headers: { source: "browser-form" }
  }));
};
</script>
</body>
</html>"#,
    )
}
/// `GET /ws`: upgrades the request to a WebSocket and hands it to
/// [`websocket_client`] together with a fresh broadcast receiver and the
/// shared producer.
async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
    // Subscribing before the upgrade means records broadcast while the handshake
    // completes are already queued for this client.
    let rx = state.broadcast_tx.subscribe();
    let producer = Arc::clone(&state.producer);
    ws.on_upgrade(move |socket| websocket_client(socket, rx, producer))
}
async fn websocket_client(
mut socket: WebSocket,
mut rx: broadcast::Receiver<String>,
producer: Arc<KafkaProducer>,
) {
if socket
.send(Message::Text(
r#"{"type":"connected","message":"subscribed to Kafka record stream; send JSON to produce records"}"#.into(),
))
.await
.is_err()
{
return;
}
loop {
tokio::select! {
message = rx.recv() => {
match message {
Ok(message) => {
if socket.send(Message::Text(message.into())).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Lagged(skipped)) => {
let warning = format!(
r#"{{"type":"lagged","skipped":{skipped}}}"#
);
if socket.send(Message::Text(warning.into())).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Closed) => break,
}
}
incoming = socket.recv() => {
match incoming {
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Text(text))) => {
let reply = match produce_from_websocket(&producer, &text).await {
Ok(ack) => ack,
Err(error) => format!(r#"{{"type":"produce_error","message":{}}}"#, json_string(&error.to_string())),
};
if socket.send(Message::Text(reply.into())).await.is_err() {
break;
}
}
Some(Ok(Message::Binary(bytes))) => {
let text = String::from_utf8_lossy(&bytes);
let reply = match produce_from_websocket(&producer, &text).await {
Ok(ack) => ack,
Err(error) => format!(r#"{{"type":"produce_error","message":{}}}"#, json_string(&error.to_string())),
};
if socket.send(Message::Text(reply.into())).await.is_err() {
break;
}
}
Some(Ok(_)) => {}
Some(Err(_)) => break,
}
}
}
}
}
async fn produce_from_websocket(producer: &KafkaProducer, text: &str) -> Result<String> {
let input: ProduceCommand = serde_json::from_str(text)
.context("expected JSON like {\"key\":\"id\",\"value\":\"...\"}")?;
let mut message = match input.value {
Some(value) => KafkaMessage::new(value),
None => KafkaMessage::tombstone(),
};
if let Some(key) = input.key {
message = message.with_key(key);
}
if let Some(partition) = input.partition {
message = message.with_partition(partition);
}
if let Some(headers) = input.headers {
for (key, value) in headers {
message = message.with_header(RecordHeader::new(key, value));
}
}
let ack = producer
.send_message(message)
.await
.context("failed to produce websocket message")?;
Ok(format!(
r#"{{"type":"produced","topic":{},"partition":{},"base_offset":{}}}"#,
json_string(&ack.topic),
ack.partition,
ack.base_offset
))
}
/// Background task: consumes `topic` as `group_id` and publishes every record,
/// encoded as a [`BroadcastRecord`] JSON string, to `broadcast_tx`.
///
/// The consumer starts from `AutoOffsetReset::Earliest` (presumably the earliest
/// offset when the group has no committed position; confirm against kafkit) and
/// polls with a 500 ms timeout. A non-empty batch is committed only after all its
/// records have been handed to the channel. A failed `send` (no subscribers) is
/// ignored, so records consumed while no client is connected are still committed
/// and never shown.
///
/// # Errors
///
/// Never returns `Ok`. It returns the first connect, poll, encode or commit error.
async fn consume_and_broadcast(
    bootstrap: String,
    topic: String,
    group_id: String,
    broadcast_tx: broadcast::Sender<String>,
) -> Result<()> {
    let consumer = KafkaClient::new(bootstrap)
        .topic(topic)
        .consumer(group_id)
        .with_client_id("kafkit-websocket-broadcast-consumer")
        .with_auto_offset_reset(AutoOffsetReset::Earliest)
        .with_poll_timeout(Duration::from_millis(500))
        .connect()
        .await
        .context("failed to connect background consumer")?;

    loop {
        let batch = consumer
            .poll()
            .await
            .context("background consumer poll failed")?;
        if !batch.is_empty() {
            for record in batch.iter() {
                let encoded = BroadcastRecord::from(record);
                let payload = serde_json::to_string(&encoded)
                    .context("failed to encode websocket record")?;
                // Err only means nobody is subscribed right now; that is fine.
                let _ = broadcast_tx.send(payload);
            }
            consumer
                .commit(&batch)
                .await
                .context("background consumer commit failed")?;
        }
    }
}
/// Creates `topic` through the admin API unless `list_topics` already reports it.
///
/// The new topic is declared as `NewTopic::new(topic, 3, 1)`, presumably 3
/// partitions with replication factor 1 (confirm against kafkit's `NewTopic`).
/// NOTE(review): list-then-create is not atomic; if another process creates the
/// topic in between, `create_topics` may fail and abort startup.
async fn ensure_topic(bootstrap: &str, topic: &str) -> Result<()> {
    let admin = KafkaAdmin::connect(AdminConfig::new(bootstrap))
        .await
        .context("failed to connect admin client")?;

    let already_exists = admin
        .list_topics()
        .await
        .context("failed to list topics")?
        .iter()
        .any(|listing| listing.name == topic);

    if !already_exists {
        admin
            .create_topics([NewTopic::new(topic, 3, 1)])
            .await
            .with_context(|| format!("failed to create topic {topic}"))?;
    }
    Ok(())
}
/// JSON command a WebSocket client sends to produce one record.
///
/// Every field is optional; absent fields (and explicit `null`s) deserialize to `None`.
#[derive(Deserialize)]
struct ProduceCommand {
    /// Record key; no key is set when absent.
    key: Option<String>,
    /// Record payload; when absent or `null` the record is sent as a tombstone.
    value: Option<String>,
    /// Explicit target partition; no partition is set on the message when absent.
    partition: Option<i32>,
    /// String headers; a `BTreeMap` so they are applied in sorted key order.
    headers: Option<std::collections::BTreeMap<String, String>>,
}
/// JSON object pushed to every WebSocket subscriber for each consumed record.
///
/// serde emits fields in declaration order, so this order is the wire key order.
#[derive(Serialize)]
struct BroadcastRecord {
    /// Message discriminator shown to clients (serialized as `"type"`).
    r#type: &'static str,
    topic: String,
    partition: i32,
    offset: i64,
    /// Record timestamp exactly as reported by the consumer record.
    timestamp: i64,
    /// Key bytes rendered as text; `None` when the record has no key.
    key: Option<String>,
    /// Value bytes rendered as text; `None` for tombstones.
    value: Option<String>,
    headers: Vec<BroadcastHeader>,
}

/// One record header in a [`BroadcastRecord`].
#[derive(Serialize)]
struct BroadcastHeader {
    key: String,
    /// Header value rendered as text; `None` when the header carries no value.
    value: Option<String>,
}
/// Converts a consumed record into its broadcast JSON form, tagged `"record"`.
///
/// Key, value and header bytes go through [`bytes_to_string`], so invalid UTF-8
/// is replaced rather than rejected.
impl From<&ConsumerRecord> for BroadcastRecord {
    fn from(record: &ConsumerRecord) -> Self {
        let mut headers = Vec::new();
        for header in record.headers.iter() {
            headers.push(BroadcastHeader {
                key: header.key.clone(),
                value: header.value.as_deref().map(bytes_to_string),
            });
        }

        let key = record.key.as_deref().map(bytes_to_string);
        let value = record.value.as_deref().map(bytes_to_string);

        Self {
            r#type: "record",
            topic: record.topic.clone(),
            partition: record.partition,
            offset: record.offset,
            timestamp: record.timestamp,
            key,
            value,
            headers,
        }
    }
}
/// Decodes `bytes` as UTF-8 into an owned `String`, replacing each invalid
/// sequence with U+FFFD (REPLACEMENT CHARACTER).
fn bytes_to_string(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    String::from(decoded)
}
fn json_string(value: &str) -> String {
serde_json::to_string(value).expect("string serialization cannot fail")
}