use std::collections::HashSet;
use polars::prelude::{AnyValue, DataFrame, TimeUnit};
use crate::point::{escape_measurement, escape_string_field, escape_tag};
use crate::{error::Error, precision::Precision};
/// Converts one cell into an escaped line-protocol tag value.
///
/// Returns `None` when the tag must be omitted for this row: for nulls, and
/// for values that are empty after formatting (line protocol does not allow
/// empty tag values; the tag has to be left out of the tag set instead).
/// Non-string values are rendered through their `Display` text.
fn to_tag_value(val: AnyValue<'_>) -> Option<String> {
    let escaped = match val {
        AnyValue::Null => return None,
        AnyValue::String(s) => escape_tag(s).into_owned(),
        AnyValue::StringOwned(s) => escape_tag(s.as_str()).into_owned(),
        other => escape_tag(&format!("{other}")).into_owned(),
    };
    // `host=` with nothing after it is rejected by the parser; drop the tag.
    (!escaped.is_empty()).then_some(escaped)
}
/// Converts one cell into a line-protocol field value.
///
/// Returns `None` when the field should be omitted for this row.
///
/// Encoding by type:
/// * null → `None`
/// * booleans → `true` / `false`
/// * signed integers → `<n>i`; unsigned integers → `<n>u`
/// * floats → decimal text, with `.0` appended to integral values; NaN and
///   ±infinity → `None`, because line protocol has no representation for them
/// * strings → double-quoted with `escape_string_field` applied
/// * `Datetime` / `Date` / `Duration` / `Time` → the raw physical integer as
///   `<n>i` (the time unit is discarded, no conversion is applied)
/// * anything else → its `Display` text as a quoted, escaped string
fn to_field_value(val: AnyValue<'_>) -> Option<String> {
    /// Formats a float field value from its precomputed finiteness/integrality.
    fn float_field<F: std::fmt::Display>(v: F, finite: bool, integral: bool) -> Option<String> {
        if !finite {
            // NaN / inf / -inf cannot be encoded in line protocol; emitting
            // them would make the whole batch unparseable, so skip the field.
            return None;
        }
        // `Display` prints integral floats without a decimal point (e.g. `3`);
        // append `.0` so the literal still reads as a float.
        Some(if integral { format!("{v}.0") } else { v.to_string() })
    }

    match val {
        AnyValue::Null => None,
        AnyValue::Boolean(v) => Some(v.to_string()),
        AnyValue::Int8(v) => Some(format!("{v}i")),
        AnyValue::Int16(v) => Some(format!("{v}i")),
        AnyValue::Int32(v) => Some(format!("{v}i")),
        AnyValue::Int64(v) => Some(format!("{v}i")),
        AnyValue::Int128(v) => Some(format!("{v}i")),
        AnyValue::UInt8(v) => Some(format!("{v}u")),
        AnyValue::UInt16(v) => Some(format!("{v}u")),
        AnyValue::UInt32(v) => Some(format!("{v}u")),
        AnyValue::UInt64(v) => Some(format!("{v}u")),
        AnyValue::UInt128(v) => Some(format!("{v}u")),
        AnyValue::Float32(v) => float_field(v, v.is_finite(), v.fract() == 0.0),
        AnyValue::Float64(v) => float_field(v, v.is_finite(), v.fract() == 0.0),
        AnyValue::Float16(v) => {
            let f = f32::from(v);
            float_field(f, f.is_finite(), f.fract() == 0.0)
        }
        AnyValue::String(s) => Some(format!("\"{}\"", escape_string_field(s))),
        AnyValue::StringOwned(s) => Some(format!("\"{}\"", escape_string_field(s.as_str()))),
        AnyValue::Datetime(v, _, _) | AnyValue::DatetimeOwned(v, _, _) => Some(format!("{v}i")),
        AnyValue::Date(v) => Some(format!("{v}i")),
        AnyValue::Duration(v, _) => Some(format!("{v}i")),
        AnyValue::Time(v) => Some(format!("{v}i")),
        other => Some(format!("\"{}\"", escape_string_field(&format!("{other}")))),
    }
}
/// Converts a timestamp-column cell into the integer timestamp written at the
/// end of a line.
///
/// * `Int32` / `Int64` / `UInt32` / `UInt64` cells are passed through without
///   any unit conversion. NOTE(review): they are presumably expected to already
///   be in `precision` — this is not checked here.
/// * `Datetime` cells are normalised to nanoseconds from their column time
///   unit, then rescaled with `Precision::scale_timestamp`.
/// * Everything else yields `None`:
///   - nulls and other dtypes;
///   - `UInt64` values above `i64::MAX`;
///   - datetimes whose nanosecond value would overflow `i64` (beyond roughly
///     the year 2262).
fn to_timestamp(val: AnyValue<'_>, precision: Precision) -> Option<i64> {
    match val {
        AnyValue::Null => None,
        AnyValue::Int64(v) => Some(v),
        AnyValue::Int32(v) => Some(i64::from(v)),
        // A plain `as` cast would wrap large values into negative timestamps.
        AnyValue::UInt64(v) => i64::try_from(v).ok(),
        AnyValue::UInt32(v) => Some(i64::from(v)),
        AnyValue::Datetime(v, tu, _) | AnyValue::DatetimeOwned(v, tu, _) => {
            // Checked: unchecked multiplication panics in debug builds and
            // silently wraps in release builds for far-future datetimes.
            let nanos = match tu {
                TimeUnit::Nanoseconds => Some(v),
                TimeUnit::Microseconds => v.checked_mul(1_000),
                TimeUnit::Milliseconds => v.checked_mul(1_000_000),
            }?;
            Some(precision.scale_timestamp(nanos))
        }
        _ => None,
    }
}
pub fn dataframe_to_line_protocol(
df: &DataFrame,
measurement: &str,
tags: &[&str],
timestamp_column: Option<&str>,
precision: Precision,
) -> Result<String, Error> {
let height = df.height();
if height == 0 {
return Ok(String::new());
}
let meas_escaped = escape_measurement(measurement);
let tag_set: HashSet<&str> = tags.iter().copied().collect();
let width = df.width();
let all_columns: Vec<&polars::frame::column::Column> =
(0..width).filter_map(|i| df.select_at_idx(i)).collect();
let mut lines: Vec<String> = Vec::with_capacity(height);
for row_idx in 0..height {
let mut line = String::with_capacity(128);
line.push_str(&meas_escaped);
for &tag in tags {
if let Ok(col) = df.column(tag) {
let val = col
.get(row_idx)
.map_err(|e| Error::Config(format!("polars row access error: {e}")))?;
if let Some(tv) = to_tag_value(val) {
line.push(',');
line.push_str(&escape_tag(tag));
line.push('=');
line.push_str(&tv);
}
}
}
line.push(' ');
let field_start = line.len();
let mut first_field = true;
for col in &all_columns {
let name = col.name().as_str();
if tag_set.contains(name) || Some(name) == timestamp_column {
continue;
}
let val = col
.get(row_idx)
.map_err(|e| Error::Config(format!("polars row access error: {e}")))?;
if let Some(fv) = to_field_value(val) {
if !first_field {
line.push(',');
}
line.push_str(&escape_tag(name));
line.push('=');
line.push_str(&fv);
first_field = false;
}
}
if line.len() == field_start {
continue;
}
if let Some(ts_col) = timestamp_column {
if let Ok(ts_column) = df.column(ts_col) {
let val = ts_column
.get(row_idx)
.map_err(|e| Error::Config(format!("polars row access error: {e}")))?;
if let Some(ts) = to_timestamp(val, precision) {
line.push(' ');
line.push_str(&ts.to_string());
}
}
}
lines.push(line);
}
Ok(lines.join("\n"))
}
pub struct DataFrameWrite<'a> {
df: &'a polars::frame::DataFrame,
measurement: String,
tags: Vec<String>,
timestamp_column: Option<String>,
}
impl<'a> DataFrameWrite<'a> {
pub fn new(df: &'a polars::frame::DataFrame, measurement: impl Into<String>) -> Self {
Self {
df,
measurement: measurement.into(),
tags: Vec::new(),
timestamp_column: None,
}
}
pub fn tags(mut self, tags: &[impl AsRef<str>]) -> Self {
self.tags = tags.iter().map(|s| s.as_ref().to_string()).collect();
self
}
pub fn timestamp_column(mut self, col: impl Into<String>) -> Self {
self.timestamp_column = Some(col.into());
self
}
}
impl crate::write::WriteInput for DataFrameWrite<'_> {
    /// Splits the frame into row slices of at most `opts.batch_size` rows
    /// (at least 1) and serialises each slice to line protocol on demand.
    ///
    /// * Slices that produce no lines (e.g. every row had only null fields)
    ///   are skipped.
    /// * The first serialisation error is yielded as an `Err` item and ends
    ///   the iteration.
    fn into_lp_batches(
        self,
        opts: &crate::write::WriteOptions,
    ) -> Box<dyn Iterator<Item = crate::Result<Vec<u8>>> + Send> {
        let precision = opts.precision;
        let batch_size = opts.batch_size.max(1);
        let height = self.df.height();
        // The returned iterator is `'static`, so it cannot borrow `self.df`.
        // Cloning a polars `DataFrame` only bumps its columns' reference counts.
        // Owning that clone lets each batch be serialised when requested,
        // instead of materialising the whole frame's line protocol up front.
        let df = self.df.clone();
        let measurement = self.measurement;
        let tags = self.tags;
        let timestamp_column = self.timestamp_column;

        let batches = (0..height)
            .step_by(batch_size)
            .scan(false, move |failed, start| {
                if *failed {
                    // An error was already yielded: stop iterating.
                    return None;
                }
                let end = (start + batch_size).min(height);
                let slice = df.slice(start as i64, end - start);
                let tag_refs: Vec<&str> = tags.iter().map(String::as_str).collect();
                let batch = match dataframe_to_line_protocol(
                    &slice,
                    &measurement,
                    &tag_refs,
                    timestamp_column.as_deref(),
                    precision,
                ) {
                    Ok(lp) if lp.is_empty() => None,
                    Ok(lp) => Some(Ok(lp.into_bytes())),
                    Err(e) => {
                        *failed = true;
                        Some(Err(e))
                    }
                };
                // `Some(None)` = skip an empty batch but keep iterating.
                Some(batch)
            })
            .flatten();
        Box::new(batches)
    }
}
/// Unit tests for `dataframe_to_line_protocol`.
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    /// A single row covering escaping (measurement, tag value, string field),
    /// each field encoding, and an integer timestamp column.
    #[test]
    fn full_serialisation() {
        let df = df![
            "host" => ["srv,1"],
            "msg" => [r#"say "hi""#],
            "cpu_pct" => [42.5_f64],
            "mem_mb" => [8192_i64],
            "online" => [true],
            "ts" => [1_700_000_000_000_i64],
        ]
        .unwrap();
        let lp = dataframe_to_line_protocol(
            &df,
            "m,name",
            &["host"],
            Some("ts"),
            Precision::Millisecond,
        )
        .unwrap();
        // Commas are escaped in both the measurement and the tag value.
        assert!(lp.starts_with(r"m\,name,host=srv\,1 "), "got: {lp}");
        // Float, signed-integer, boolean and quoted/escaped string fields.
        assert!(lp.contains("cpu_pct=42.5"));
        assert!(lp.contains("mem_mb=8192i"));
        assert!(lp.contains("online=true"));
        assert!(lp.contains(r#"msg="say \"hi\"""#));
        // An Int64 timestamp column is emitted as-is.
        assert!(lp.ends_with("1700000000000"));
        // The tag column must not be repeated in the field set (the second
        // space-separated element of the line).
        assert!(!lp.split(' ').nth(1).unwrap().contains("host="));
    }

    /// A row whose only field is null is dropped entirely, and an empty frame
    /// serialises to an empty string.
    #[test]
    fn null_and_empty_handling() {
        let df = df![
            "v" => [Some(1.0_f64), None::<f64>],
            "ts" => [100_i64, 200_i64],
        ]
        .unwrap();
        let lp =
            dataframe_to_line_protocol(&df, "m", &[], Some("ts"), Precision::Nanosecond).unwrap();
        // Only the first row has a field value.
        assert_eq!(lp.lines().count(), 1);
        // Integral floats keep a trailing `.0`.
        assert!(lp.contains("v=1.0"));
        let df = df!["v" => Vec::<i64>::new()].unwrap();
        assert!(
            dataframe_to_line_protocol(&df, "m", &[], None, Precision::Nanosecond)
                .unwrap()
                .is_empty()
        );
    }
}