Skip to main content

kyu_copy/
kafka_reader.rs

1//! Kafka topic reader for streaming ingestion.
2//!
3//! Consumes JSON-encoded messages from a Kafka topic and yields rows
4//! matching the target table schema. Each JSON message must be a flat
5//! object whose keys correspond to column names.
6//!
7//! Usage (via Cypher):
8//! ```cypher
9//! COPY Person FROM 'kafka://broker:9092/topic_name'
10//! ```
11
12use std::time::Duration;
13
14use kyu_common::{KyuError, KyuResult};
15use kyu_types::{LogicalType, TypedValue};
16use rdkafka::config::ClientConfig;
17use rdkafka::consumer::{BaseConsumer, Consumer};
18use rdkafka::message::Message;
19use smol_str::SmolStr;
20
21use crate::{DataReader, parse_field};
22
/// Default timeout for polling a single message.
///
/// A poll that returns nothing within this window is treated by the iterator
/// as "no more messages available": the reader is marked finished and stops
/// yielding rows.
const POLL_TIMEOUT: Duration = Duration::from_secs(5);
25
/// Maximum number of messages to consume per batch.
///
/// Applied by `KafkaReader::open` when the caller passes `batch_size == 0`,
/// and always requested by `KafkaReader::from_url`.
const DEFAULT_BATCH_SIZE: usize = 10_000;
28
/// Reads rows from a Kafka topic.
///
/// Each consumed message is expected to be a JSON object with keys matching
/// the `column_names` provided at construction. Missing keys yield NULL.
pub struct KafkaReader {
    /// Subscribed consumer that messages are polled from.
    consumer: BaseConsumer,
    /// Target column types, in output-row order.
    schema: Vec<LogicalType>,
    /// JSON key looked up for each column; indexed in step with `schema`.
    column_names: Vec<SmolStr>,
    /// Number of messages that may still be consumed before iteration stops.
    remaining: usize,
    /// Set once a poll times out or fails; no further rows are produced.
    finished: bool,
}
40
/// Parsed Kafka connection info from a URL like `kafka://broker:9092/topic`.
///
/// Produced by `parse_kafka_url`. All fields hold the text exactly as it
/// appeared in the URL; no percent-decoding is applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KafkaUrl {
    /// Comma-separated broker list (the part between `kafka://` and the last `/`).
    pub brokers: String,
    /// Topic name (the part after the last `/`, before any `?`).
    pub topic: String,
    /// Value of the `group_id` query parameter, if one was supplied.
    pub group_id: Option<String>,
}
47
48/// Parse a Kafka URL: `kafka://host:port/topic[?group_id=xxx]`
49pub fn parse_kafka_url(url: &str) -> KyuResult<KafkaUrl> {
50    let stripped = url
51        .strip_prefix("kafka://")
52        .ok_or_else(|| KyuError::Copy("Kafka URL must start with 'kafka://'".into()))?;
53
54    // Split off query string.
55    let (path, query) = stripped.split_once('?').unwrap_or((stripped, ""));
56    let group_id = query
57        .split('&')
58        .find_map(|kv| kv.strip_prefix("group_id="))
59        .map(String::from);
60
61    // Split broker from topic.
62    let (brokers, topic) = path
63        .rsplit_once('/')
64        .ok_or_else(|| KyuError::Copy("Kafka URL must contain /topic after broker".into()))?;
65
66    if brokers.is_empty() || topic.is_empty() {
67        return Err(KyuError::Copy(
68            "Kafka URL must have non-empty broker and topic".into(),
69        ));
70    }
71
72    Ok(KafkaUrl {
73        brokers: brokers.to_string(),
74        topic: topic.to_string(),
75        group_id,
76    })
77}
78
79impl KafkaReader {
80    /// Connect to a Kafka topic and prepare for consumption.
81    ///
82    /// - `brokers`: comma-separated Kafka broker addresses (e.g., "localhost:9092")
83    /// - `topic`: Kafka topic to consume from
84    /// - `group_id`: consumer group ID
85    /// - `schema`: target column types
86    /// - `column_names`: column names corresponding to JSON keys
87    /// - `batch_size`: maximum messages to consume (0 = use default)
88    pub fn open(
89        brokers: &str,
90        topic: &str,
91        group_id: &str,
92        schema: &[LogicalType],
93        column_names: &[SmolStr],
94        batch_size: usize,
95    ) -> KyuResult<Self> {
96        let consumer: BaseConsumer = ClientConfig::new()
97            .set("bootstrap.servers", brokers)
98            .set("group.id", group_id)
99            .set("auto.offset.reset", "earliest")
100            .set("enable.auto.commit", "true")
101            .create()
102            .map_err(|e| KyuError::Copy(format!("Kafka consumer error: {e}")))?;
103
104        consumer
105            .subscribe(&[topic])
106            .map_err(|e| KyuError::Copy(format!("Kafka subscribe error: {e}")))?;
107
108        let batch = if batch_size == 0 {
109            DEFAULT_BATCH_SIZE
110        } else {
111            batch_size
112        };
113
114        Ok(Self {
115            consumer,
116            schema: schema.to_vec(),
117            column_names: column_names.to_vec(),
118            remaining: batch,
119            finished: false,
120        })
121    }
122
123    /// Open from a Kafka URL.
124    pub fn from_url(
125        url: &str,
126        schema: &[LogicalType],
127        column_names: &[SmolStr],
128    ) -> KyuResult<Self> {
129        let parsed = parse_kafka_url(url)?;
130        let group_id = parsed
131            .group_id
132            .unwrap_or_else(|| format!("kyugraph-{}", parsed.topic));
133        Self::open(
134            &parsed.brokers,
135            &parsed.topic,
136            &group_id,
137            schema,
138            column_names,
139            DEFAULT_BATCH_SIZE,
140        )
141    }
142
143    /// Parse a JSON message payload into a row of TypedValues.
144    fn parse_json_message(&self, payload: &[u8]) -> KyuResult<Vec<TypedValue>> {
145        // Parse JSON with simd-json for performance.
146        let mut buf = payload.to_vec();
147        let value: simd_json::OwnedValue = simd_json::to_owned_value(&mut buf)
148            .map_err(|e| KyuError::Copy(format!("JSON parse error: {e}")))?;
149
150        let obj = match &value {
151            simd_json::OwnedValue::Object(map) => map,
152            _ => return Err(KyuError::Copy("Kafka message must be a JSON object".into())),
153        };
154
155        let mut row = Vec::with_capacity(self.schema.len());
156        for (col_idx, col_type) in self.schema.iter().enumerate() {
157            let col_name = &self.column_names[col_idx];
158            let val = match obj.get(col_name.as_str()) {
159                Some(json_val) => json_value_to_typed(json_val, col_type)?,
160                None => TypedValue::Null,
161            };
162            row.push(val);
163        }
164        Ok(row)
165    }
166}
167
168impl DataReader for KafkaReader {
169    fn schema(&self) -> &[LogicalType] {
170        &self.schema
171    }
172}
173
174impl Iterator for KafkaReader {
175    type Item = KyuResult<Vec<TypedValue>>;
176
177    fn next(&mut self) -> Option<Self::Item> {
178        if self.finished || self.remaining == 0 {
179            return None;
180        }
181
182        // Poll for a message with timeout.
183        match self.consumer.poll(POLL_TIMEOUT) {
184            Some(Ok(msg)) => {
185                self.remaining -= 1;
186                match msg.payload() {
187                    Some(payload) => Some(self.parse_json_message(payload)),
188                    None => {
189                        // Tombstone (null payload) → skip.
190                        self.next()
191                    }
192                }
193            }
194            Some(Err(e)) => {
195                self.finished = true;
196                Some(Err(KyuError::Copy(format!("Kafka poll error: {e}"))))
197            }
198            None => {
199                // Timeout — no more messages available.
200                self.finished = true;
201                None
202            }
203        }
204    }
205}
206
207/// Convert a simd-json value to a TypedValue.
208fn json_value_to_typed(val: &simd_json::OwnedValue, target: &LogicalType) -> KyuResult<TypedValue> {
209    use simd_json::OwnedValue;
210
211    if matches!(val, OwnedValue::Static(simd_json::StaticNode::Null)) {
212        return Ok(TypedValue::Null);
213    }
214
215    match target {
216        LogicalType::Bool => match val {
217            OwnedValue::Static(simd_json::StaticNode::Bool(b)) => Ok(TypedValue::Bool(*b)),
218            _ => Err(KyuError::Copy(format!("expected bool, got {val:?}"))),
219        },
220        LogicalType::Int8 => match val {
221            OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int8(*n as i8)),
222            _ => Err(KyuError::Copy(format!("expected int8, got {val:?}"))),
223        },
224        LogicalType::Int16 => match val {
225            OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int16(*n as i16)),
226            _ => Err(KyuError::Copy(format!("expected int16, got {val:?}"))),
227        },
228        LogicalType::Int32 => match val {
229            OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int32(*n as i32)),
230            _ => Err(KyuError::Copy(format!("expected int32, got {val:?}"))),
231        },
232        LogicalType::Int64 | LogicalType::Serial => match val {
233            OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int64(*n)),
234            OwnedValue::Static(simd_json::StaticNode::U64(n)) => Ok(TypedValue::Int64(*n as i64)),
235            _ => Err(KyuError::Copy(format!("expected int64, got {val:?}"))),
236        },
237        LogicalType::Float => match val {
238            OwnedValue::Static(simd_json::StaticNode::F64(f)) => Ok(TypedValue::Float(*f as f32)),
239            OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Float(*n as f32)),
240            _ => Err(KyuError::Copy(format!("expected float, got {val:?}"))),
241        },
242        LogicalType::Double => match val {
243            OwnedValue::Static(simd_json::StaticNode::F64(f)) => Ok(TypedValue::Double(*f)),
244            OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Double(*n as f64)),
245            _ => Err(KyuError::Copy(format!("expected double, got {val:?}"))),
246        },
247        LogicalType::String => match val {
248            OwnedValue::String(s) => Ok(TypedValue::String(SmolStr::new(s))),
249            other => {
250                // Coerce non-string to string representation.
251                Ok(TypedValue::String(SmolStr::new(format!("{other}"))))
252            }
253        },
254        _ => {
255            // Fallback: convert to string via parse_field.
256            let s = match val {
257                OwnedValue::String(s) => s.clone(),
258                other => format!("{other}"),
259            };
260            parse_field(&s, target)
261        }
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use simd_json::{OwnedValue, StaticNode};

    /// Builds a reader over an `Int64` "id" column and a `String` "name"
    /// column. The consumer is created but never subscribed or polled, so no
    /// broker needs to be running.
    fn id_name_reader() -> KafkaReader {
        let consumer: BaseConsumer = ClientConfig::new()
            .set("bootstrap.servers", "localhost:9092")
            .set("group.id", "test")
            .create()
            .unwrap();

        KafkaReader {
            consumer,
            schema: vec![LogicalType::Int64, LogicalType::String],
            column_names: vec![SmolStr::new("id"), SmolStr::new("name")],
            remaining: 100,
            finished: false,
        }
    }

    // --- URL parsing -----------------------------------------------------

    #[test]
    fn parse_kafka_url_basic() {
        let parsed = parse_kafka_url("kafka://localhost:9092/my_topic").unwrap();
        assert_eq!(parsed.brokers, "localhost:9092");
        assert_eq!(parsed.topic, "my_topic");
        assert!(parsed.group_id.is_none());
    }

    #[test]
    fn parse_kafka_url_with_group() {
        let parsed =
            parse_kafka_url("kafka://broker1:9092,broker2:9092/events?group_id=ingest").unwrap();
        assert_eq!(parsed.brokers, "broker1:9092,broker2:9092");
        assert_eq!(parsed.topic, "events");
        assert_eq!(parsed.group_id.as_deref(), Some("ingest"));
    }

    #[test]
    fn parse_kafka_url_invalid() {
        let bad_urls = [
            "http://localhost:9092/topic",
            "kafka:///topic",
            "kafka://broker/",
        ];
        for bad in bad_urls {
            assert!(parse_kafka_url(bad).is_err());
        }
    }

    // --- JSON value conversion -------------------------------------------

    #[test]
    fn json_value_bool() {
        let json = OwnedValue::Static(StaticNode::Bool(true));
        let typed = json_value_to_typed(&json, &LogicalType::Bool).unwrap();
        assert_eq!(typed, TypedValue::Bool(true));
    }

    #[test]
    fn json_value_int64() {
        let json = OwnedValue::Static(StaticNode::I64(42));
        let typed = json_value_to_typed(&json, &LogicalType::Int64).unwrap();
        assert_eq!(typed, TypedValue::Int64(42));
    }

    #[test]
    fn json_value_double() {
        let json = OwnedValue::Static(StaticNode::F64(3.14));
        let typed = json_value_to_typed(&json, &LogicalType::Double).unwrap();
        assert_eq!(typed, TypedValue::Double(3.14));
    }

    #[test]
    fn json_value_string() {
        let json = OwnedValue::String("hello".to_string());
        let typed = json_value_to_typed(&json, &LogicalType::String).unwrap();
        assert_eq!(typed, TypedValue::String(SmolStr::new("hello")));
    }

    #[test]
    fn json_value_null() {
        let json = OwnedValue::Static(StaticNode::Null);
        let typed = json_value_to_typed(&json, &LogicalType::Int64).unwrap();
        assert_eq!(typed, TypedValue::Null);
    }

    #[test]
    fn json_value_int_to_double() {
        let json = OwnedValue::Static(StaticNode::I64(7));
        let typed = json_value_to_typed(&json, &LogicalType::Double).unwrap();
        assert_eq!(typed, TypedValue::Double(7.0));
    }

    // --- Whole-message parsing -------------------------------------------

    #[test]
    fn json_parse_message() {
        let reader = id_name_reader();

        let row = reader
            .parse_json_message(br#"{"id": 42, "name": "Alice"}"#)
            .unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TypedValue::Int64(42));
        assert_eq!(row[1], TypedValue::String(SmolStr::new("Alice")));
    }

    #[test]
    fn json_parse_missing_field() {
        let reader = id_name_reader();

        let row = reader.parse_json_message(br#"{"id": 99}"#).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TypedValue::Int64(99));
        assert_eq!(row[1], TypedValue::Null);
    }
}