1use std::time::Duration;
13
14use kyu_common::{KyuError, KyuResult};
15use kyu_types::{LogicalType, TypedValue};
16use rdkafka::config::ClientConfig;
17use rdkafka::consumer::{BaseConsumer, Consumer};
18use rdkafka::message::Message;
19use smol_str::SmolStr;
20
21use crate::{DataReader, parse_field};
22
/// How long a single consumer poll waits for a message; a poll that times out
/// ends iteration of the reader.
const POLL_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Number of messages a reader consumes when the caller passes a batch size
/// of 0 (and always for readers created with `KafkaReader::from_url`).
const DEFAULT_BATCH_SIZE: usize = 10_000;
28
/// Reads rows from a Kafka topic, decoding each message payload as a JSON
/// object whose fields are mapped onto `schema` by column name.
///
/// Iteration ends once `remaining` messages have been consumed, when a poll
/// times out, or after a poll error has been reported.
pub struct KafkaReader {
    /// Consumer subscribed to the source topic.
    consumer: BaseConsumer,
    /// Column types of every produced row, in order.
    schema: Vec<LogicalType>,
    /// JSON field name looked up for each `schema` column, matched by position.
    column_names: Vec<SmolStr>,
    /// Messages this reader may still consume (payload-less ones included).
    remaining: usize,
    /// Set after a poll timeout or poll error; no further polls are issued.
    finished: bool,
}
40
/// Components of a `kafka://<brokers>/<topic>[?group_id=<id>]` URL, as
/// returned by `parse_kafka_url`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KafkaUrl {
    /// Bootstrap broker list, passed verbatim to Kafka (may be comma-separated,
    /// e.g. `host1:9092,host2:9092`).
    pub brokers: String,
    /// Topic to consume from.
    pub topic: String,
    /// Consumer group id from the `group_id` query parameter, if one was given.
    pub group_id: Option<String>,
}
47
48pub fn parse_kafka_url(url: &str) -> KyuResult<KafkaUrl> {
50 let stripped = url
51 .strip_prefix("kafka://")
52 .ok_or_else(|| KyuError::Copy("Kafka URL must start with 'kafka://'".into()))?;
53
54 let (path, query) = stripped.split_once('?').unwrap_or((stripped, ""));
56 let group_id = query
57 .split('&')
58 .find_map(|kv| kv.strip_prefix("group_id="))
59 .map(String::from);
60
61 let (brokers, topic) = path
63 .rsplit_once('/')
64 .ok_or_else(|| KyuError::Copy("Kafka URL must contain /topic after broker".into()))?;
65
66 if brokers.is_empty() || topic.is_empty() {
67 return Err(KyuError::Copy(
68 "Kafka URL must have non-empty broker and topic".into(),
69 ));
70 }
71
72 Ok(KafkaUrl {
73 brokers: brokers.to_string(),
74 topic: topic.to_string(),
75 group_id,
76 })
77}
78
79impl KafkaReader {
80 pub fn open(
89 brokers: &str,
90 topic: &str,
91 group_id: &str,
92 schema: &[LogicalType],
93 column_names: &[SmolStr],
94 batch_size: usize,
95 ) -> KyuResult<Self> {
96 let consumer: BaseConsumer = ClientConfig::new()
97 .set("bootstrap.servers", brokers)
98 .set("group.id", group_id)
99 .set("auto.offset.reset", "earliest")
100 .set("enable.auto.commit", "true")
101 .create()
102 .map_err(|e| KyuError::Copy(format!("Kafka consumer error: {e}")))?;
103
104 consumer
105 .subscribe(&[topic])
106 .map_err(|e| KyuError::Copy(format!("Kafka subscribe error: {e}")))?;
107
108 let batch = if batch_size == 0 {
109 DEFAULT_BATCH_SIZE
110 } else {
111 batch_size
112 };
113
114 Ok(Self {
115 consumer,
116 schema: schema.to_vec(),
117 column_names: column_names.to_vec(),
118 remaining: batch,
119 finished: false,
120 })
121 }
122
123 pub fn from_url(
125 url: &str,
126 schema: &[LogicalType],
127 column_names: &[SmolStr],
128 ) -> KyuResult<Self> {
129 let parsed = parse_kafka_url(url)?;
130 let group_id = parsed
131 .group_id
132 .unwrap_or_else(|| format!("kyugraph-{}", parsed.topic));
133 Self::open(
134 &parsed.brokers,
135 &parsed.topic,
136 &group_id,
137 schema,
138 column_names,
139 DEFAULT_BATCH_SIZE,
140 )
141 }
142
143 fn parse_json_message(&self, payload: &[u8]) -> KyuResult<Vec<TypedValue>> {
145 let mut buf = payload.to_vec();
147 let value: simd_json::OwnedValue = simd_json::to_owned_value(&mut buf)
148 .map_err(|e| KyuError::Copy(format!("JSON parse error: {e}")))?;
149
150 let obj = match &value {
151 simd_json::OwnedValue::Object(map) => map,
152 _ => return Err(KyuError::Copy("Kafka message must be a JSON object".into())),
153 };
154
155 let mut row = Vec::with_capacity(self.schema.len());
156 for (col_idx, col_type) in self.schema.iter().enumerate() {
157 let col_name = &self.column_names[col_idx];
158 let val = match obj.get(col_name.as_str()) {
159 Some(json_val) => json_value_to_typed(json_val, col_type)?,
160 None => TypedValue::Null,
161 };
162 row.push(val);
163 }
164 Ok(row)
165 }
166}
167
168impl DataReader for KafkaReader {
169 fn schema(&self) -> &[LogicalType] {
170 &self.schema
171 }
172}
173
174impl Iterator for KafkaReader {
175 type Item = KyuResult<Vec<TypedValue>>;
176
177 fn next(&mut self) -> Option<Self::Item> {
178 if self.finished || self.remaining == 0 {
179 return None;
180 }
181
182 match self.consumer.poll(POLL_TIMEOUT) {
184 Some(Ok(msg)) => {
185 self.remaining -= 1;
186 match msg.payload() {
187 Some(payload) => Some(self.parse_json_message(payload)),
188 None => {
189 self.next()
191 }
192 }
193 }
194 Some(Err(e)) => {
195 self.finished = true;
196 Some(Err(KyuError::Copy(format!("Kafka poll error: {e}"))))
197 }
198 None => {
199 self.finished = true;
201 None
202 }
203 }
204 }
205}
206
207fn json_value_to_typed(val: &simd_json::OwnedValue, target: &LogicalType) -> KyuResult<TypedValue> {
209 use simd_json::OwnedValue;
210
211 if matches!(val, OwnedValue::Static(simd_json::StaticNode::Null)) {
212 return Ok(TypedValue::Null);
213 }
214
215 match target {
216 LogicalType::Bool => match val {
217 OwnedValue::Static(simd_json::StaticNode::Bool(b)) => Ok(TypedValue::Bool(*b)),
218 _ => Err(KyuError::Copy(format!("expected bool, got {val:?}"))),
219 },
220 LogicalType::Int8 => match val {
221 OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int8(*n as i8)),
222 _ => Err(KyuError::Copy(format!("expected int8, got {val:?}"))),
223 },
224 LogicalType::Int16 => match val {
225 OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int16(*n as i16)),
226 _ => Err(KyuError::Copy(format!("expected int16, got {val:?}"))),
227 },
228 LogicalType::Int32 => match val {
229 OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int32(*n as i32)),
230 _ => Err(KyuError::Copy(format!("expected int32, got {val:?}"))),
231 },
232 LogicalType::Int64 | LogicalType::Serial => match val {
233 OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Int64(*n)),
234 OwnedValue::Static(simd_json::StaticNode::U64(n)) => Ok(TypedValue::Int64(*n as i64)),
235 _ => Err(KyuError::Copy(format!("expected int64, got {val:?}"))),
236 },
237 LogicalType::Float => match val {
238 OwnedValue::Static(simd_json::StaticNode::F64(f)) => Ok(TypedValue::Float(*f as f32)),
239 OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Float(*n as f32)),
240 _ => Err(KyuError::Copy(format!("expected float, got {val:?}"))),
241 },
242 LogicalType::Double => match val {
243 OwnedValue::Static(simd_json::StaticNode::F64(f)) => Ok(TypedValue::Double(*f)),
244 OwnedValue::Static(simd_json::StaticNode::I64(n)) => Ok(TypedValue::Double(*n as f64)),
245 _ => Err(KyuError::Copy(format!("expected double, got {val:?}"))),
246 },
247 LogicalType::String => match val {
248 OwnedValue::String(s) => Ok(TypedValue::String(SmolStr::new(s))),
249 other => {
250 Ok(TypedValue::String(SmolStr::new(format!("{other}"))))
252 }
253 },
254 _ => {
255 let s = match val {
257 OwnedValue::String(s) => s.clone(),
258 other => format!("{other}"),
259 };
260 parse_field(&s, target)
261 }
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a reader over an `id: Int64, name: String` schema without
    /// subscribing it to any topic; only `parse_json_message` is exercised.
    fn id_name_reader() -> KafkaReader {
        KafkaReader {
            consumer: ClientConfig::new()
                .set("bootstrap.servers", "localhost:9092")
                .set("group.id", "test")
                .create()
                .unwrap(),
            schema: vec![LogicalType::Int64, LogicalType::String],
            column_names: vec![SmolStr::new("id"), SmolStr::new("name")],
            remaining: 100,
            finished: false,
        }
    }

    #[test]
    fn parse_kafka_url_basic() {
        let url = parse_kafka_url("kafka://localhost:9092/my_topic").unwrap();
        assert_eq!(url.brokers, "localhost:9092");
        assert_eq!(url.topic, "my_topic");
        assert!(url.group_id.is_none());
    }

    #[test]
    fn parse_kafka_url_with_group() {
        let url =
            parse_kafka_url("kafka://broker1:9092,broker2:9092/events?group_id=ingest").unwrap();
        assert_eq!(url.brokers, "broker1:9092,broker2:9092");
        assert_eq!(url.topic, "events");
        assert_eq!(url.group_id.as_deref(), Some("ingest"));
    }

    #[test]
    fn parse_kafka_url_ignores_other_query_params() {
        let url = parse_kafka_url("kafka://b:9092/t?foo=bar&group_id=g").unwrap();
        assert_eq!(url.topic, "t");
        assert_eq!(url.group_id.as_deref(), Some("g"));
    }

    #[test]
    fn parse_kafka_url_invalid() {
        assert!(parse_kafka_url("http://localhost:9092/topic").is_err());
        assert!(parse_kafka_url("kafka:///topic").is_err());
        assert!(parse_kafka_url("kafka://broker/").is_err());
    }

    #[test]
    fn json_value_bool() {
        let val = simd_json::OwnedValue::Static(simd_json::StaticNode::Bool(true));
        let result = json_value_to_typed(&val, &LogicalType::Bool).unwrap();
        assert_eq!(result, TypedValue::Bool(true));
    }

    #[test]
    fn json_value_int64() {
        let val = simd_json::OwnedValue::Static(simd_json::StaticNode::I64(42));
        let result = json_value_to_typed(&val, &LogicalType::Int64).unwrap();
        assert_eq!(result, TypedValue::Int64(42));
    }

    #[test]
    fn json_value_double() {
        // 2.5 is exactly representable and avoids clippy's `approx_constant`
        // lint, which rejects literals close to PI such as 3.14.
        let val = simd_json::OwnedValue::Static(simd_json::StaticNode::F64(2.5));
        let result = json_value_to_typed(&val, &LogicalType::Double).unwrap();
        assert_eq!(result, TypedValue::Double(2.5));
    }

    #[test]
    fn json_value_string() {
        let val = simd_json::OwnedValue::String("hello".to_string());
        let result = json_value_to_typed(&val, &LogicalType::String).unwrap();
        assert_eq!(result, TypedValue::String(SmolStr::new("hello")));
    }

    #[test]
    fn json_value_null() {
        let val = simd_json::OwnedValue::Static(simd_json::StaticNode::Null);
        let result = json_value_to_typed(&val, &LogicalType::Int64).unwrap();
        assert_eq!(result, TypedValue::Null);
    }

    #[test]
    fn json_value_int_to_double() {
        let val = simd_json::OwnedValue::Static(simd_json::StaticNode::I64(7));
        let result = json_value_to_typed(&val, &LogicalType::Double).unwrap();
        assert_eq!(result, TypedValue::Double(7.0));
    }

    #[test]
    fn json_value_int_to_float() {
        let val = simd_json::OwnedValue::Static(simd_json::StaticNode::I64(3));
        let result = json_value_to_typed(&val, &LogicalType::Float).unwrap();
        assert_eq!(result, TypedValue::Float(3.0));
    }

    #[test]
    fn json_parse_message() {
        let reader = id_name_reader();
        let payload = br#"{"id": 42, "name": "Alice"}"#;
        let row = reader.parse_json_message(payload).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TypedValue::Int64(42));
        assert_eq!(row[1], TypedValue::String(SmolStr::new("Alice")));
    }

    #[test]
    fn json_parse_missing_field() {
        let reader = id_name_reader();
        let payload = br#"{"id": 99}"#;
        let row = reader.parse_json_message(payload).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TypedValue::Int64(99));
        assert_eq!(row[1], TypedValue::Null);
    }

    #[test]
    fn json_parse_non_object_is_error() {
        let reader = id_name_reader();
        assert!(reader.parse_json_message(b"[1, 2]").is_err());
    }

    #[test]
    fn json_parse_invalid_json_is_error() {
        let reader = id_name_reader();
        assert!(reader.parse_json_message(b"{not json").is_err());
    }
}