kafkit-client 0.1.9

Kafka 4.0+ pure Rust client.
Documentation
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use axum::Router;
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::{Html, IntoResponse};
use axum::routing::get;
use kafkit_client::{
    AdminConfig, AutoOffsetReset, ConsumerRecord, KafkaAdmin, KafkaClient, KafkaMessage,
    KafkaProducer, NewTopic, RecordHeader,
};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::broadcast;

/// Kafka bootstrap address used when `KAFKIT_BOOTSTRAP_SERVERS` is not set.
const DEFAULT_BOOTSTRAP: &str = "localhost:9092";
/// Topic used when `KAFKIT_TOPIC` is not set.
const DEFAULT_TOPIC: &str = "kafkit.orders";
/// Consumer group id used when `KAFKIT_GROUP_ID` is not set.
const DEFAULT_GROUP: &str = "kafkit-websocket-broadcaster";
/// Listen address for the web server used when `KAFKIT_WS_BIND` is not set.
const DEFAULT_BIND: &str = "127.0.0.1:8081";
/// Capacity passed to `broadcast::channel` for the record fan-out channel.
const BROADCAST_CAPACITY: usize = 1024;

/// Entry point: connects a Kafka producer, spawns a background consumer that
/// fans records out over a broadcast channel, and serves the axum routes
/// `/`, `/health` and `/ws`.
///
/// Configuration comes from the environment, falling back to the `DEFAULT_*`
/// constants when a variable is unset:
///
/// - `KAFKIT_BOOTSTRAP_SERVERS` — Kafka bootstrap address
/// - `KAFKIT_TOPIC` — topic to produce to and consume from
/// - `KAFKIT_GROUP_ID` — consumer group id of the background consumer
/// - `KAFKIT_WS_BIND` — listen address, an IP socket address such as
///   `127.0.0.1:8081` (it is parsed as a `SocketAddr`, so hostnames like
///   `localhost:8081` are rejected)
///
/// # Errors
///
/// Returns an error when the bind address cannot be parsed, the topic cannot
/// be ensured, the producer cannot connect, the listener cannot bind, or the
/// web server fails. A failure of the background consumer is only logged to
/// stderr; the web server keeps running.
#[tokio::main]
async fn main() -> Result<()> {
    let bootstrap =
        env::var("KAFKIT_BOOTSTRAP_SERVERS").unwrap_or_else(|_| DEFAULT_BOOTSTRAP.to_owned());
    let topic = env::var("KAFKIT_TOPIC").unwrap_or_else(|_| DEFAULT_TOPIC.to_owned());
    let group_id = env::var("KAFKIT_GROUP_ID").unwrap_or_else(|_| DEFAULT_GROUP.to_owned());
    // `SocketAddr` parsing accepts IP literals only, so the message must not
    // advertise "host:port": a value such as `localhost:8081` fails here.
    let bind: SocketAddr = env::var("KAFKIT_WS_BIND")
        .unwrap_or_else(|_| DEFAULT_BIND.to_owned())
        .parse()
        .context(
            "KAFKIT_WS_BIND must be an IP socket address (ip:port), for example 127.0.0.1:8081; hostnames are not resolved",
        )?;

    // Make sure the topic exists before any client touches it.
    ensure_topic(&bootstrap, &topic).await?;

    let producer = KafkaClient::new(bootstrap.clone())
        .topic(topic.clone())
        .producer()
        .with_client_id("kafkit-websocket-producer")
        .connect()
        .await
        .context("failed to connect websocket producer")?;

    // The initial receiver is discarded; each websocket subscribes on its own.
    let (broadcast_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
    let app_state = Arc::new(AppState {
        broadcast_tx: broadcast_tx.clone(),
        producer: Arc::new(producer),
    });

    // The consumer runs detached; if it stops, the error is logged and the
    // web server continues to serve (producing still works).
    tokio::spawn({
        let bootstrap = bootstrap.clone();
        let topic = topic.clone();
        let group_id = group_id.clone();
        async move {
            if let Err(error) =
                consume_and_broadcast(bootstrap, topic, group_id, broadcast_tx).await
            {
                eprintln!("background consumer stopped: {error:#}");
            }
        }
    });

    let app = Router::new()
        .route("/", get(index))
        .route("/health", get(health))
        .route("/ws", get(ws_handler))
        .with_state(app_state);

    println!("Kafka bootstrap: {bootstrap}");
    println!("Kafka topic: {topic}");
    println!("Kafka group: {group_id}");
    println!("WebSocket UI: http://{bind}");
    println!("WebSocket endpoint: ws://{bind}/ws");

    let listener = TcpListener::bind(bind)
        .await
        .with_context(|| format!("failed to bind {bind}"))?;
    axum::serve(listener, app)
        .await
        .context("web server failed")?;

    Ok(())
}

/// State shared with the axum handlers; `main` wraps it in an `Arc` and
/// installs it with `Router::with_state`.
#[derive(Clone)]
struct AppState {
    /// Sending half of the record broadcast channel. `ws_handler` calls
    /// `subscribe()` on it to obtain a receiver per websocket connection.
    broadcast_tx: broadcast::Sender<String>,
    /// Kafka producer shared by all websocket connections.
    producer: Arc<KafkaProducer>,
}

/// Handler for `GET /health`: always answers with the plain-text body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}

/// Handler for `GET /`: serves a static single-page UI.
///
/// The page opens a websocket to `/ws` on the same host, appends every
/// received JSON message to a log, and sends the form contents
/// (`key`, `value`, and a fixed `source` header) as a JSON produce command.
///
/// The websocket scheme follows the page's own scheme (`wss` when the page is
/// loaded over HTTPS, `ws` otherwise) because browsers block insecure
/// websockets from secure pages. The form refuses to send while the socket is
/// not open, since `WebSocket.send` throws while the socket is still
/// connecting.
async fn index() -> Html<&'static str> {
    Html(
        r#"<!doctype html>
<html>
<head>
  <meta charset="utf-8">
  <title>kafkit-client websocket broadcast</title>
  <style>
    body { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; margin: 2rem; max-width: 72rem; }
    form { display: grid; gap: .75rem; grid-template-columns: 12rem 1fr auto; align-items: end; margin-bottom: 1rem; }
    label { display: grid; gap: .25rem; }
    input { font: inherit; padding: .5rem; }
    button { font: inherit; padding: .55rem .9rem; cursor: pointer; }
    pre { background: #111827; color: #e5e7eb; padding: 1rem; border-radius: 6px; min-height: 20rem; white-space: pre-wrap; }
  </style>
</head>
<body>
  <h1>Kafka records</h1>
  <p>Send a message over the websocket. The server produces it to Kafka; the background consumer broadcasts it back to every subscriber.</p>
  <form id="producer">
    <label>Key <input id="key" value="browser-1"></label>
    <label>Value <input id="value" value='{"source":"browser","status":"created"}'></label>
    <button type="submit">Produce</button>
  </form>
  <pre id="log"></pre>
  <script>
    const log = document.getElementById("log");
    const form = document.getElementById("producer");
    const key = document.getElementById("key");
    const value = document.getElementById("value");
    const scheme = location.protocol === "https:" ? "wss" : "ws";
    const ws = new WebSocket(`${scheme}://${location.host}/ws`);
    ws.onopen = () => log.textContent += "connected\n";
    ws.onmessage = (event) => {
      log.textContent += JSON.stringify(JSON.parse(event.data), null, 2) + "\n\n";
      log.scrollTop = log.scrollHeight;
    };
    ws.onclose = () => log.textContent += "closed\n";
    ws.onerror = () => log.textContent += "websocket error\n";
    form.onsubmit = (event) => {
      event.preventDefault();
      if (ws.readyState !== WebSocket.OPEN) {
        log.textContent += "not connected; message not sent\n";
        return;
      }
      ws.send(JSON.stringify({
        key: key.value || null,
        value: value.value,
        headers: { source: "browser-form" }
      }));
    };
  </script>
</body>
</html>"#,
    )
}

async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let rx = state.broadcast_tx.subscribe();
    ws.on_upgrade(move |socket| websocket_client(socket, rx, state.producer.clone()))
}

async fn websocket_client(
    mut socket: WebSocket,
    mut rx: broadcast::Receiver<String>,
    producer: Arc<KafkaProducer>,
) {
    if socket
        .send(Message::Text(
            r#"{"type":"connected","message":"subscribed to Kafka record stream; send JSON to produce records"}"#.into(),
        ))
        .await
        .is_err()
    {
        return;
    }

    loop {
        tokio::select! {
            message = rx.recv() => {
                match message {
                    Ok(message) => {
                        if socket.send(Message::Text(message.into())).await.is_err() {
                            break;
                        }
                    }
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        let warning = format!(
                            r#"{{"type":"lagged","skipped":{skipped}}}"#
                        );
                        if socket.send(Message::Text(warning.into())).await.is_err() {
                            break;
                        }
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
            incoming = socket.recv() => {
                match incoming {
                    Some(Ok(Message::Close(_))) | None => break,
                    Some(Ok(Message::Text(text))) => {
                        let reply = match produce_from_websocket(&producer, &text).await {
                            Ok(ack) => ack,
                            Err(error) => format!(r#"{{"type":"produce_error","message":{}}}"#, json_string(&error.to_string())),
                        };
                        if socket.send(Message::Text(reply.into())).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Binary(bytes))) => {
                        let text = String::from_utf8_lossy(&bytes);
                        let reply = match produce_from_websocket(&producer, &text).await {
                            Ok(ack) => ack,
                            Err(error) => format!(r#"{{"type":"produce_error","message":{}}}"#, json_string(&error.to_string())),
                        };
                        if socket.send(Message::Text(reply.into())).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(_)) => {}
                    Some(Err(_)) => break,
                }
            }
        }
    }
}

async fn produce_from_websocket(producer: &KafkaProducer, text: &str) -> Result<String> {
    let input: ProduceCommand = serde_json::from_str(text)
        .context("expected JSON like {\"key\":\"id\",\"value\":\"...\"}")?;

    let mut message = match input.value {
        Some(value) => KafkaMessage::new(value),
        None => KafkaMessage::tombstone(),
    };

    if let Some(key) = input.key {
        message = message.with_key(key);
    }
    if let Some(partition) = input.partition {
        message = message.with_partition(partition);
    }
    if let Some(headers) = input.headers {
        for (key, value) in headers {
            message = message.with_header(RecordHeader::new(key, value));
        }
    }

    let ack = producer
        .send_message(message)
        .await
        .context("failed to produce websocket message")?;

    Ok(format!(
        r#"{{"type":"produced","topic":{},"partition":{},"base_offset":{}}}"#,
        json_string(&ack.topic),
        ack.partition,
        ack.base_offset
    ))
}

/// Background task: polls Kafka forever and publishes every record, encoded
/// as a `BroadcastRecord` JSON string, on `broadcast_tx`.
///
/// Each non-empty batch is committed after all of its records have been
/// handed to the broadcast channel. Empty polls are skipped.
///
/// # Errors
///
/// Returns (and thereby stops consuming) when connecting, polling, encoding a
/// record or committing fails. The function never returns `Ok`.
async fn consume_and_broadcast(
    bootstrap: String,
    topic: String,
    group_id: String,
    broadcast_tx: broadcast::Sender<String>,
) -> Result<()> {
    let poll_timeout = Duration::from_millis(500);
    let consumer = KafkaClient::new(bootstrap)
        .topic(topic)
        .consumer(group_id)
        .with_client_id("kafkit-websocket-broadcast-consumer")
        // NOTE(review): presumably starts from the earliest offset when the
        // group has no committed position — confirm against kafkit_client.
        .with_auto_offset_reset(AutoOffsetReset::Earliest)
        .with_poll_timeout(poll_timeout)
        .connect()
        .await
        .context("failed to connect background consumer")?;

    loop {
        let batch = consumer
            .poll()
            .await
            .context("background consumer poll failed")?;

        if !batch.is_empty() {
            for record in batch.iter() {
                let encoded = serde_json::to_string(&BroadcastRecord::from(record))
                    .context("failed to encode websocket record")?;
                // The send result is deliberately ignored: an error here only
                // means nobody is subscribed at the moment.
                let _ = broadcast_tx.send(encoded);
            }

            consumer
                .commit(&batch)
                .await
                .context("background consumer commit failed")?;
        }
    }
}

/// Creates `topic` on the cluster at `bootstrap` unless a topic with that
/// name is already listed.
///
/// # Errors
///
/// Fails when the admin client cannot connect, the topic listing fails, or
/// topic creation fails.
async fn ensure_topic(bootstrap: &str, topic: &str) -> Result<()> {
    let admin = KafkaAdmin::connect(AdminConfig::new(bootstrap))
        .await
        .context("failed to connect admin client")?;

    let already_present = admin
        .list_topics()
        .await
        .context("failed to list topics")?
        .iter()
        .any(|listing| listing.name == topic);

    if !already_present {
        // NOTE(review): `3, 1` presumably mean partition count and replication
        // factor — confirm against `NewTopic::new`.
        admin
            .create_topics([NewTopic::new(topic, 3, 1)])
            .await
            .with_context(|| format!("failed to create topic {topic}"))?;
    }

    Ok(())
}

/// JSON command a websocket client sends to produce one Kafka message.
/// Every field is optional.
#[derive(Deserialize)]
struct ProduceCommand {
    /// Message key; no key is set when absent.
    key: Option<String>,
    /// Message value; when absent (or `null`) a tombstone message is produced.
    value: Option<String>,
    /// Explicit target partition; left to the producer when absent.
    partition: Option<i32>,
    /// Record headers as a name-to-value map.
    headers: Option<std::collections::BTreeMap<String, String>>,
}

/// JSON shape of one consumed Kafka record as broadcast to websocket clients.
#[derive(Serialize)]
struct BroadcastRecord {
    /// Message discriminator, serialized under the JSON key `type`; the
    /// `From<&ConsumerRecord>` conversion sets it to `"record"`.
    r#type: &'static str,
    /// Topic the record was read from.
    topic: String,
    /// Partition the record was read from.
    partition: i32,
    /// Offset of the record within its partition.
    offset: i64,
    /// Record timestamp copied from the consumer record.
    /// NOTE(review): unit is presumably epoch milliseconds — confirm.
    timestamp: i64,
    /// Record key decoded lossily as UTF-8, or `None` when the record has no key.
    key: Option<String>,
    /// Record value decoded lossily as UTF-8, or `None` when the record has no value.
    value: Option<String>,
    /// Record headers in the order the consumer record yields them.
    headers: Vec<BroadcastHeader>,
}

/// JSON shape of a single record header inside a `BroadcastRecord`.
#[derive(Serialize)]
struct BroadcastHeader {
    /// Header name.
    key: String,
    /// Header value decoded lossily as UTF-8, or `None` when the header has no value.
    value: Option<String>,
}

impl From<&ConsumerRecord> for BroadcastRecord {
    fn from(record: &ConsumerRecord) -> Self {
        Self {
            r#type: "record",
            topic: record.topic.clone(),
            partition: record.partition,
            offset: record.offset,
            timestamp: record.timestamp,
            key: record.key.as_deref().map(bytes_to_string),
            value: record.value.as_deref().map(bytes_to_string),
            headers: record
                .headers
                .iter()
                .map(|header| BroadcastHeader {
                    key: header.key.clone(),
                    value: header.value.as_deref().map(bytes_to_string),
                })
                .collect(),
        }
    }
}

/// Decodes `bytes` as UTF-8 into an owned `String`, substituting U+FFFD for
/// any invalid sequences.
fn bytes_to_string(bytes: &[u8]) -> String {
    match String::from_utf8_lossy(bytes) {
        // Valid UTF-8: copy the borrowed text.
        std::borrow::Cow::Borrowed(text) => text.to_owned(),
        // Invalid sequences were replaced; the repaired copy is already owned.
        std::borrow::Cow::Owned(repaired) => repaired,
    }
}

/// Encodes `value` as a JSON string literal, including the surrounding quotes
/// and any required escaping, for splicing into hand-built JSON.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize the string, which is not
/// expected to happen.
fn json_string(value: &str) -> String {
    let encoded = serde_json::to_string(value);
    encoded.expect("string serialization cannot fail")
}