#[cfg(feature = "python")]
use crate::analyze::utils::format_standard_kline;
use crate::objects::{
bar::{NewBar, RawBar, Symbol},
bi::BI,
direction::Direction,
freq::Freq,
fx::FX,
mark::Mark,
};
use derive_builder::Builder;
#[cfg(feature = "python")]
use parking_lot::RwLock;
#[cfg(feature = "python")]
use polars::prelude::*;
#[cfg(feature = "python")]
use std::io::Cursor;
#[cfg(feature = "python")]
use std::sync::Arc;
use utils::{check_bi, check_fxs, remove_include};
pub mod errors;
pub mod utils;
#[cfg(feature = "python")]
use crate::utils::common::freq_to_chinese_string;
#[cfg(feature = "python")]
use crate::utils::common::{create_naive_pandas_timestamp, create_ordered_dict};
#[cfg(feature = "python")]
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
#[cfg(feature = "python")]
use pyo3::types::{PyBytesMethods, PyDict};
#[cfg(feature = "python")]
use pyo3::{Py, PyErr, PyObject, PyResult, Python};
#[cfg(feature = "python")]
use pyo3::{pyclass, pymethods};
#[cfg(feature = "python")]
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Incremental Chan-theory analyzer for a single symbol at a single frequency.
///
/// Raw bars are fed one at a time through `CZSC::update_bar`. Each new raw bar is merged
/// into `bars_ubi` via `remove_include`, and completed strokes are collected in `bi_list`.
#[cfg_attr(feature = "python", gen_stub_pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))]
#[derive(Debug, Clone, Builder)]
pub struct CZSC {
/// Maximum number of completed strokes kept in `bi_list`; older strokes are drained from
/// the front after every update.
pub max_bi_num: usize,
/// Raw bars retained by the analyzer. After each update, bars earlier than the first
/// retained stroke's starting bar are dropped.
pub bars_raw: Vec<RawBar>,
/// Merged bars that follow the last completed stroke (the pending, unfinished part).
pub bars_ubi: Vec<NewBar>,
/// Completed strokes, oldest first, capped at `max_bi_num` entries.
pub bi_list: Vec<BI>,
/// Symbol taken from the first bar passed to `CZSC::new`.
pub symbol: Symbol,
/// Frequency taken from the first bar passed to `CZSC::new`.
pub freq: Freq,
/// Python-side cache dict, shared between clones through the `Arc`. It is created lazily
/// by the `cache` property getter and can be replaced through its setter.
#[cfg(feature = "python")]
#[builder(default = "Arc::new(RwLock::new(None))")]
pub cache: Arc<RwLock<Option<Py<PyDict>>>>,
}
impl CZSC {
fn sync_extended_last_ubi_in_bis(&mut self, last_ubi: &NewBar, bar: &RawBar) {
#[inline]
fn patch_new_bar_if_same(nb: &mut NewBar, target: &NewBar, bar: &RawBar) {
if nb == target
&& let Some(last) = nb.elements.last_mut()
&& last.dt == bar.dt
{
*last = bar.clone();
}
}
#[inline]
fn patch_fx_if_same(fx: &mut FX, target: &NewBar, bar: &RawBar) {
for nb in &mut fx.elements {
patch_new_bar_if_same(nb, target, bar);
}
}
for bi in &mut self.bi_list {
for nb in &mut bi.bars {
patch_new_bar_if_same(nb, last_ubi, bar);
}
patch_fx_if_same(&mut bi.fx_a, last_ubi, bar);
patch_fx_if_same(&mut bi.fx_b, last_ubi, bar);
for fx in &mut bi.fxs {
patch_fx_if_same(fx, last_ubi, bar);
}
}
}
/// Builds an analyzer by replaying `bars_raw`, in order, through [`CZSC::update_bar`].
///
/// `symbol` and `freq` are taken from the first bar. `max_bi_num` caps how many completed
/// strokes are retained in `bi_list` after each update.
///
/// # Panics
///
/// Panics if `bars_raw` is empty, because the symbol and frequency cannot be determined.
/// Also propagates any panic raised by `update_bar`.
pub fn new(bars_raw: Vec<RawBar>, max_bi_num: usize) -> Self {
    let first = bars_raw
        .first()
        .expect("CZSC::new requires at least one bar to determine symbol and freq");
    let mut czsc = Self {
        max_bi_num,
        symbol: first.symbol.clone(),
        freq: first.freq,
        // Capacity hints only; every container is filled by `update_bar` below.
        bars_raw: Vec::with_capacity(bars_raw.len()),
        bars_ubi: Vec::with_capacity(bars_raw.len() / 2),
        bi_list: Vec::with_capacity(max_bi_num.min(bars_raw.len() / 10)),
        #[cfg(feature = "python")]
        cache: Arc::new(RwLock::new(None)),
    };
    for raw in bars_raw {
        czsc.update_bar(raw);
    }
    czsc
}
pub fn get_fx_list(&self) -> Vec<FX> {
let mut fxs = Vec::new();
for bi_ in self.bi_list.iter() {
fxs.extend_from_slice(&bi_.fxs[1..]);
}
if let Some(ubi_fxs) = self.get_ubi_fxs() {
for x in ubi_fxs {
if fxs.is_empty() || x.dt > fxs.last().unwrap().dt {
fxs.push(x);
}
}
}
fxs
}
/// Feeds one raw bar into the analyzer, updating merged bars and strokes.
///
/// * If `bars_raw` is empty or `bar.dt` differs from the last stored raw bar's `dt`, `bar` is
///   appended as a new raw bar and merged on its own.
/// * Otherwise `bar` replaces the last raw bar in place. The last merged bar is popped from
///   `bars_ubi`, copies of it inside `bi_list` are patched, and its constituent raw bars are
///   re-merged with the final one replaced by `bar`.
///
/// After merging, strokes are re-evaluated and `bi_list` is trimmed to the newest
/// `max_bi_num` strokes. `bars_raw` is then trimmed so it starts at the first raw bar whose
/// `dt` is not earlier than the first remaining stroke's starting bar.
///
/// # Panics
///
/// Panics if:
/// * a same-`dt` update arrives while `bars_ubi` is empty;
/// * the popped merged bar's last raw element does not share `bar.dt` (the "时间错位"
///   assertion); or
/// * `remove_include` returns an error.
pub fn update_bar(&mut self, bar: RawBar) {
    let extends_last = self.bars_raw.last().is_some_and(|last| last.dt == bar.dt);

    // Raw bars that still have to be merged into `bars_ubi`.
    let pending: Vec<RawBar> = if extends_last {
        *self
            .bars_raw
            .last_mut()
            .expect("bars_raw is non-empty when extending the last bar") = bar.clone();
        let last_ubi = self
            .bars_ubi
            .pop()
            .expect("bars_ubi must hold the merged bar being extended");
        self.sync_extended_last_ubi_in_bis(&last_ubi, &bar);
        let mut replay = last_ubi.elements.to_vec();
        assert_eq!(
            bar.dt,
            replay.last().unwrap().dt,
            "时间错位: {} != {}",
            bar.dt,
            replay.last().unwrap().dt
        );
        *replay.last_mut().unwrap() = bar;
        replay
    } else {
        self.bars_raw.push(bar.clone());
        vec![bar]
    };

    // Merge each pending raw bar against the last two merged bars. Bars are consumed by
    // value, so no per-bar clone is needed.
    for raw in pending {
        let n = self.bars_ubi.len();
        if n < 2 {
            self.bars_ubi.push(NewBar::new_from_raw(&raw));
            continue;
        }
        let (has_include, merged) =
            remove_include(&self.bars_ubi[n - 2], &self.bars_ubi[n - 1], raw).unwrap();
        if has_include {
            self.bars_ubi[n - 1] = merged;
        } else {
            self.bars_ubi.push(merged);
        }
    }

    self.__update_bi();

    // Keep only the newest `max_bi_num` strokes.
    let excess = self.bi_list.len().saturating_sub(self.max_bi_num);
    if excess > 0 {
        self.bi_list.drain(..excess);
    }

    // Drop raw bars that precede the first retained stroke (bars_raw is ordered by dt).
    if let Some(first_bi) = self.bi_list.first() {
        let sdt = first_bi.fx_a.elements[0].dt;
        let keep_from = self.bars_raw.partition_point(|raw| raw.dt < sdt);
        self.bars_raw.drain(..keep_from);
    }
}
/// Re-evaluates strokes (BI) from the pending merged bars in `bars_ubi`.
///
/// The `Option<()>` return type exists only so `?` can short-circuit; every path yields
/// `None`.
///
/// * With no stroke yet, an anchor fractal is chosen and `check_bi` runs on the bars from
///   that anchor onward.
/// * Otherwise `check_bi` runs on all pending bars. Afterwards, if the newest merged bar breaks
///   the last stroke's extreme, that stroke is discarded and its bars are returned to
///   `bars_ubi`.
fn __update_bi(&mut self) -> Option<()> {
// Not enough merged bars to form anything yet.
if self.bars_ubi.len() < 3 {
return None;
}
if self.bi_list.is_empty() {
// First stroke: pick the anchor fractal `fx_a` among fractals with the same mark as the
// first one found. For bottoms (D) keep the lowest low, for tops (G) the highest high;
// `<=` / `>=` make a later fractal win ties.
let fxs = check_fxs(&self.bars_ubi);
let first = fxs.first()?;
let fx_a = fxs
.iter()
.filter(|x| x.mark == first.mark)
.reduce(|acc, x| match first.mark {
Mark::D if x.low <= acc.low => x,
Mark::G if x.high >= acc.high => x,
_ => acc,
})
.unwrap_or(first);
// Only merged bars from the anchor's first bar onward can belong to the first stroke.
let bars_ubi = self
.bars_ubi
.iter()
.filter(|x| x.dt >= fx_a.elements[0].dt)
.collect::<Vec<_>>();
// `check_bi` yields an optional stroke plus the merged bars left over after it.
let (bi, bars_ubi_) = check_bi(&bars_ubi);
if let Some(bi) = bi {
self.bi_list.push(bi);
}
self.bars_ubi = bars_ubi_.iter().map(|&bar| bar.clone()).collect::<Vec<_>>();
return None;
}
// Subsequent strokes: try to complete a new stroke from all pending merged bars.
let (bi, bars_ubi_) = check_bi(&self.bars_ubi);
if let Some(bi) = bi {
self.bi_list.push(bi);
}
self.bars_ubi = bars_ubi_.to_vec();
// Post-check: if the newest merged bar exceeds the last stroke's extreme in its own
// direction (a higher high for Up, a lower low for Down), the stroke is treated as broken.
let last_bi = self.bi_list.last().unwrap(); let bars_ubi = &self.bars_ubi;
if bars_ubi.last().is_some()
&& ((last_bi.direction == Direction::Up
&& bars_ubi.last().unwrap().high > last_bi.get_high())
|| (last_bi.direction == Direction::Down
&& bars_ubi.last().unwrap().low < last_bi.get_low()))
{
// Rebuild `bars_ubi` from the broken stroke's bars minus its last two, followed by the
// pending bars dated at or after the stroke's second-to-last bar. The last two are
// excluded presumably because the final merged bar may still be forming; confirm.
let merge_point = last_bi.bars[last_bi.bars.len() - 2].dt;
self.bars_ubi = last_bi.bars[..last_bi.bars.len() - 2]
.iter()
.chain(bars_ubi.iter().filter(|x| x.dt >= merge_point))
.cloned()
.collect();
// Discard the broken stroke.
self.bi_list.pop();
}
None
}
/// Fractals detected among the pending merged bars (`bars_ubi`).
///
/// Returns `None` when `bars_ubi` is empty; otherwise returns whatever `check_fxs` finds,
/// which may be an empty list.
pub fn get_ubi_fxs(&self) -> Option<Vec<FX>> {
    let has_pending = !self.bars_ubi.is_empty();
    has_pending.then(|| check_fxs(&self.bars_ubi))
}
#[allow(unused)]
fn get_ubi(&self) -> Option<UBI> {
if self.bars_ubi.is_empty() || self.bi_list.is_empty() {
return None;
}
let ubi_fxs = self.get_ubi_fxs()?;
let bars_raw = self
.bars_ubi
.iter()
.flat_map(|x| &x.elements)
.collect::<Vec<_>>();
let high_bar = bars_raw
.iter()
.max_by(|a, b| {
a.high
.partial_cmp(&b.high)
.unwrap_or(std::cmp::Ordering::Less)
})
.unwrap()
.to_owned()
.to_owned();
let low_bar = bars_raw
.iter()
.min_by(|a, b| {
a.low
.partial_cmp(&b.low)
.unwrap_or(std::cmp::Ordering::Greater)
})
.unwrap()
.to_owned()
.to_owned();
let direction = if self.bi_list.last().unwrap().direction == Direction::Down {
Direction::Up
} else {
Direction::Down
};
let fx_a = ubi_fxs.first().unwrap().to_owned();
Some(UBI {
symbol: self.symbol.clone(),
direction,
high: high_bar.high,
low: low_bar.low,
high_bar,
low_bar,
bars: self.bars_ubi.to_owned(),
raw_bars: self.bars_raw.to_owned(),
fxs: ubi_fxs,
fx_a,
})
}
}
#[cfg(feature = "python")]
#[cfg_attr(feature = "python", gen_stub_pymethods)]
#[cfg_attr(feature = "python", pymethods)]
impl CZSC {
#[new]
#[pyo3(signature = (bars_raw, max_bi_num=50))]
pub fn new_py(bars_raw: Vec<RawBar>, max_bi_num: usize) -> PyResult<Self> {
Ok(CZSC::new(bars_raw, max_bi_num))
}
/// Builds an analyzer from a DataFrame serialized in Arrow IPC format.
///
/// The frame must contain the columns `symbol`, `dt`, `open`, `close`, `high`, `low`, `vol`
/// and `amount`, and at least one row. It is converted with `format_standard_kline` using
/// `freq`, then replayed through `CZSC::new`.
///
/// # Errors
///
/// Raises `ValueError` if the bytes cannot be read as Arrow IPC, a required column is
/// missing, the frame is empty, or kline formatting fails.
#[staticmethod]
#[pyo3(signature = (df_bytes, freq, max_bi_num=50))]
pub fn from_dataframe(
    df_bytes: pyo3::Bound<'_, pyo3::types::PyBytes>,
    freq: Freq,
    max_bi_num: usize,
) -> PyResult<Self> {
    use pyo3::exceptions::PyValueError;

    let df = IpcReader::new(Cursor::new(df_bytes.as_bytes()))
        .finish()
        .map_err(|e| PyValueError::new_err(format!("Failed to read Arrow data: {e}")))?;

    let required_columns = [
        "symbol", "dt", "open", "close", "high", "low", "vol", "amount",
    ];
    // Report the first missing column, in the order listed above.
    if let Some(col) = required_columns
        .iter()
        .find(|&col| !df.get_column_names().contains(col))
    {
        return Err(PyValueError::new_err(format!(
            "Missing required column: {col}"
        )));
    }

    if df.height() == 0 {
        return Err(PyValueError::new_err("DataFrame is empty"));
    }

    let bars = format_standard_kline(df, freq)
        .map_err(|e| PyValueError::new_err(format!("Failed to format kline data: {e}")))?;
    Ok(CZSC::new(bars, max_bi_num))
}
/// Symbol of the analyzed bars, rendered as a string.
#[getter]
fn symbol(&self) -> String {
    self.symbol.to_string()
}

/// Bar frequency taken from the first input bar.
#[getter]
fn freq(&self) -> Freq {
    self.freq
}

/// Maximum number of completed strokes retained in `bi_list`.
#[getter]
fn max_bi_num(&self) -> usize {
    self.max_bi_num
}

/// Copy of the retained completed strokes, oldest first.
#[getter]
fn bi_list(&self) -> Vec<BI> {
    self.bi_list.clone()
}

/// Copy of the retained raw bars.
#[getter]
fn bars_raw(&self) -> Vec<RawBar> {
    self.bars_raw.clone()
}
/// Retained raw bars as a `pandas.DataFrame`, one row per bar in `bars_raw`.
///
/// Columns: `symbol`, `dt` (via `create_naive_pandas_timestamp`), `freq` (via
/// `freq_to_chinese_string`), `id`, `open`, `close`, `high`, `low`, `vol`, `amount`.
#[getter]
fn bars_raw_df(&self, py: Python) -> PyResult<PyObject> {
    // Resolve the pandas class first, so a missing pandas fails before any row is built.
    let frame_cls = py.import("pandas")?.getattr("DataFrame")?;

    let mut rows: Vec<PyObject> = Vec::with_capacity(self.bars_raw.len());
    for bar in &self.bars_raw {
        let row = PyDict::new(py);
        row.set_item("symbol", bar.symbol.as_ref())?;
        row.set_item("dt", create_naive_pandas_timestamp(py, bar.dt)?)?;
        row.set_item("freq", freq_to_chinese_string(bar.freq))?;
        row.set_item("id", bar.id)?;
        row.set_item("open", bar.open)?;
        row.set_item("close", bar.close)?;
        row.set_item("high", bar.high)?;
        row.set_item("low", bar.low)?;
        row.set_item("vol", bar.vol)?;
        row.set_item("amount", bar.amount)?;
        rows.push(row.into());
    }

    Ok(frame_cls.call1((rows,))?.into())
}
/// Copy of the pending merged bars that follow the last completed stroke.
#[getter]
fn bars_ubi(&self) -> Vec<NewBar> {
    self.bars_ubi.clone()
}

/// Strokes considered finished.
///
/// With fewer than 5 pending merged bars, the newest stroke is excluded as possibly still
/// changing. Otherwise all strokes are returned. Empty when there are no strokes.
#[getter]
fn finished_bis(&self) -> Vec<BI> {
    let total = self.bi_list.len();
    let keep = if self.bars_ubi.len() < 5 {
        total.saturating_sub(1)
    } else {
        total
    };
    self.bi_list[..keep].to_vec()
}

/// All known fractals in chronological order (see `CZSC::get_fx_list`).
#[getter]
fn fx_list(&self) -> Vec<FX> {
    self.get_fx_list()
}
// The Python `cache` property is provided by `get_cache` / `set_cache`. PyO3 strips the
// `get_` / `set_` prefixes from getter and setter names. A separate `#[getter] fn cache` must
// not be defined here: it would register the same property name, and returning a fresh dict
// on each access would break the persistent cache and its setter.

/// Signals for this analyzer, as built by `create_ordered_dict`.
///
/// A new object is created on every access (presumably an empty `collections.OrderedDict`;
/// confirm against `create_ordered_dict`).
#[getter]
fn signals(&self, py: Python) -> PyResult<PyObject> {
    create_ordered_dict(py)
}
/// Fractals among the pending merged bars; an empty list when there are none.
#[getter]
fn ubi_fxs(&self) -> Vec<FX> {
    self.get_ubi_fxs().unwrap_or_else(Vec::new)
}
/// The unfinished stroke as a Python dict, or `None`.
///
/// Returns `None` when there is no completed stroke, `bars_ubi` is empty, or no fractal is
/// found among `bars_ubi`. Otherwise the dict holds, in this insertion order:
/// * `symbol`;
/// * `direction` (opposite of the last stroke's);
/// * `high`, `low`, `high_bar`, `low_bar`;
/// * `bars` (the merged bars in `bars_ubi`);
/// * `raw_bars` (raw bars flattened from `bars_ubi`);
/// * `fxs`, and `fx_a` (the first fractal).
#[getter]
fn ubi(&self, py: Python) -> PyResult<PyObject> {
    // `get_ubi_fxs` is `None` exactly when `bars_ubi` is empty, so an empty fractal list
    // also covers that case.
    let ubi_fxs = self.get_ubi_fxs().unwrap_or_default();
    let (Some(last_bi), Some(fx_a)) = (self.bi_list.last(), ubi_fxs.first().cloned()) else {
        return Ok(py.None());
    };

    let raw_bars: Vec<RawBar> = self
        .bars_ubi
        .iter()
        .flat_map(|nb| &nb.elements)
        .cloned()
        .collect();

    // Compare prices without `unwrap()` so a NaN price cannot panic inside a Python getter.
    // For ordinary (non-NaN) prices the selected bars are the same as with `partial_cmp`.
    let high_bar = raw_bars
        .iter()
        .max_by(|a, b| {
            a.high
                .partial_cmp(&b.high)
                .unwrap_or(std::cmp::Ordering::Less)
        })
        .cloned();
    let low_bar = raw_bars
        .iter()
        .min_by(|a, b| {
            a.low
                .partial_cmp(&b.low)
                .unwrap_or(std::cmp::Ordering::Greater)
        })
        .cloned();
    let (Some(high_bar), Some(low_bar)) = (high_bar, low_bar) else {
        // No raw bars behind the merged bars.
        return Ok(py.None());
    };

    let direction = if last_bi.direction == Direction::Down {
        Direction::Up
    } else {
        Direction::Down
    };

    let dict = PyDict::new(py);
    dict.set_item("symbol", self.symbol.as_ref())?;
    dict.set_item("direction", direction)?;
    dict.set_item("high", high_bar.high)?;
    dict.set_item("low", low_bar.low)?;
    dict.set_item("high_bar", high_bar)?;
    dict.set_item("low_bar", low_bar)?;
    dict.set_item("bars", self.bars_ubi.clone())?;
    dict.set_item("raw_bars", raw_bars)?;
    dict.set_item("fxs", ubi_fxs)?;
    dict.set_item("fx_a", fx_a)?;
    Ok(dict.into())
}
/// Verbose logging flag; always `false` in this implementation.
#[getter]
fn verbose(&self) -> bool {
    false
}
/// Whether the last completed stroke is currently being extended.
///
/// For an up stroke, this is `true` when some pending merged bar in `bars_ubi` has a higher
/// `high` than the stroke. For a down stroke, it is `true` when some pending bar has a lower
/// `low`. It is `false` when there are no strokes or no pending bars.
///
/// "Some bar exceeds the extreme" is equivalent to "the max/min exceeds it". Using `any`
/// avoids the `partial_cmp().unwrap()` that panicked on NaN prices.
#[getter]
fn last_bi_extend(&self) -> bool {
    let Some(last_bi) = self.bi_list.last() else {
        return false;
    };
    match last_bi.direction {
        Direction::Up => {
            let bi_high = last_bi.get_high();
            self.bars_ubi.iter().any(|bar| bar.high > bi_high)
        }
        Direction::Down => {
            let bi_low = last_bi.get_low();
            self.bars_ubi.iter().any(|bar| bar.low < bi_low)
        }
    }
}
/// Placeholder kept for API compatibility; returns an explanatory message.
#[pyo3(signature = (_renderer=None))]
fn open_in_browser(&self, _renderer: Option<&str>) -> PyResult<String> {
    const MESSAGE: &str = "Browser opening not implemented in Rust version";
    Ok(MESSAGE.to_owned())
}

/// Placeholder kept for API compatibility; returns an explanatory message.
fn to_echarts(&self) -> PyResult<String> {
    const MESSAGE: &str = "ECharts export not implemented in Rust version";
    Ok(MESSAGE.to_owned())
}

/// Placeholder kept for API compatibility; returns an explanatory message.
fn to_plotly(&self) -> PyResult<String> {
    const MESSAGE: &str = "Plotly export not implemented in Rust version";
    Ok(MESSAGE.to_owned())
}

/// Python-facing alias of `CZSC::update_bar`.
fn update(&mut self, bar: RawBar) -> PyResult<()> {
    self.update_bar(bar);
    Ok(())
}
#[getter]
fn get_cache<'py>(&'py self, py: Python<'py>) -> Py<PyDict> {
{
let cache_read = self.cache.read();
if let Some(ref cached_dict) = *cache_read {
return cached_dict.clone_ref(py);
}
}
let mut cache_write = self.cache.write();
if cache_write.is_none() {
*cache_write = Some(PyDict::new(py).unbind());
}
cache_write.as_ref().unwrap().clone_ref(py)
}
#[setter]
#[gen_stub(skip)] fn set_cache(&self, dict: Py<PyDict>) {
let mut cache_write = self.cache.write();
*cache_write = Some(dict);
}
/// Short textual summary: symbol, frequency, stroke cap and current stroke count.
fn __repr__(&self) -> String {
    let bi_count = self.bi_list.len();
    format!(
        "CZSC(symbol={}, freq={:?}, max_bi_num={}, bi_count={})",
        &self.symbol, &self.freq, self.max_bi_num, bi_count
    )
}

/// Pickle support: reconstructs via `CZSC(bars_raw, max_bi_num)`.
///
/// The retained raw bars are first replayed through a fresh `CZSC::new`, and that instance's
/// trimmed `bars_raw` is stored as the constructor argument. The Python-side `cache` dict is
/// not included.
fn __reduce__(&self, py: Python) -> PyResult<PyObject> {
    use pyo3::IntoPyObject;
    let replayed = CZSC::new(self.bars_raw.clone(), self.max_bi_num);
    let ctor_args = (replayed.bars_raw, self.max_bi_num).into_pyobject(py)?;
    let cls = py.get_type::<Self>();
    Ok((cls, ctor_args).into_pyobject(py)?.into())
}
}
/// Snapshot of the unfinished stroke: the pending merged bars after the last completed
/// stroke. Built by `CZSC::get_ubi`.
#[derive(Debug, Clone)]
pub struct UBI {
/// Symbol of the analyzer the snapshot was taken from.
pub symbol: Symbol,
/// Opposite of the last completed stroke's direction.
pub direction: Direction,
/// Highest `high` among the raw bars of the merged bars in `bars`.
pub high: f64,
/// Lowest `low` among the raw bars of the merged bars in `bars`.
pub low: f64,
/// Raw bar that carries `high`.
pub high_bar: RawBar,
/// Raw bar that carries `low`.
pub low_bar: RawBar,
/// Pending merged bars (a copy of `CZSC::bars_ubi`).
pub bars: Vec<NewBar>,
/// Raw bars attached to the snapshot; see `CZSC::get_ubi` for exactly which bars are included.
pub raw_bars: Vec<RawBar>,
/// Fractals detected among `bars`.
pub fxs: Vec<FX>,
/// First fractal in `fxs`.
pub fx_a: FX,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::analyze::utils::format_standard_kline;
use crate::objects::freq::Freq;
use chrono::NaiveDateTime;
use chrono::{DateTime, Utc};
use polars::prelude::SerReader;
use polars::prelude::{CsvReader, StringChunked, StringMethods};
use std::io::Cursor;
/// Daily bars for 002515.SZ (2025-01-02 .. 2025-02-28) as CSV text, used as the test fixture.
fn example_data() -> &'static str {
// Data lines must stay at column 0: leading spaces would become part of the CSV fields.
const CSV_DATA: &str = r#"
dt,symbol,open,close,high,low,vol,amount
2025-01-02,002515.SZ,50.73,51.29,52.97,50.62,32900684.0,152798823.0
2025-01-03,002515.SZ,51.4,48.72,51.85,48.6,33224687.0,147184323.0
2025-01-06,002515.SZ,48.83,48.6,49.39,47.48,17419634.0,75608391.0
2025-01-07,002515.SZ,48.6,48.94,49.05,48.27,13929982.0,60500438.0
2025-01-08,002515.SZ,48.27,48.04,48.94,47.26,17697397.0,75973887.0
2025-01-09,002515.SZ,48.27,48.16,48.83,47.6,14284260.0,61391856.0
2025-01-10,002515.SZ,48.04,46.92,48.94,46.81,16080374.0,68834125.0
2025-01-13,002515.SZ,46.59,46.92,47.26,45.47,12508818.0,52037636.0
2025-01-14,002515.SZ,46.92,48.16,48.27,46.92,16407679.0,69944802.0
2025-01-15,002515.SZ,49.5,49.5,50.73,49.05,29140842.0,129502353.0
2025-01-16,002515.SZ,49.5,49.72,50.28,48.94,19124511.0,84774186.0
2025-01-17,002515.SZ,49.28,50.28,51.74,49.05,22228511.0,99754272.0
2025-01-20,002515.SZ,50.4,50.4,50.73,49.61,14908933.0,66989586.0
2025-01-21,002515.SZ,50.62,50.06,50.73,49.61,11565100.0,51612511.0
2025-01-22,002515.SZ,50.06,49.16,50.06,48.83,10889797.0,47963340.0
2025-01-23,002515.SZ,49.39,48.72,49.95,48.72,13050206.0,57522568.0
2025-01-24,002515.SZ,48.49,48.83,48.94,48.27,12042388.0,52334558.0
2025-01-27,002515.SZ,49.05,49.39,51.74,49.05,22813802.0,102357601.0
2025-02-05,002515.SZ,49.39,49.16,49.95,48.72,13525075.0,59524887.0
2025-02-06,002515.SZ,48.83,49.05,49.28,48.16,17429613.0,75782611.0
2025-02-07,002515.SZ,48.94,49.5,49.95,48.72,17447114.0,76989329.0
2025-02-10,002515.SZ,49.39,50.4,50.51,49.16,18733821.0,83810683.0
2025-02-11,002515.SZ,50.4,49.84,50.73,49.61,13189816.0,58803966.0
2025-02-12,002515.SZ,50.06,50.06,50.4,49.5,15881392.0,70692291.0
2025-02-13,002515.SZ,49.84,49.84,50.51,49.61,18048669.0,80671035.0
2025-02-14,002515.SZ,49.72,49.05,49.95,48.94,17455299.0,76786904.0
2025-02-17,002515.SZ,49.16,49.39,49.61,48.6,15791678.0,69303481.0
2025-02-18,002515.SZ,49.16,47.71,49.39,47.48,20599809.0,88885983.0
2025-02-19,002515.SZ,47.48,48.04,48.16,47.37,12911258.0,55064600.0
2025-02-20,002515.SZ,48.04,48.27,48.83,47.71,12823411.0,55267260.0
2025-02-21,002515.SZ,48.27,47.6,48.72,47.48,16547084.0,70527761.0
2025-02-24,002515.SZ,47.71,52.41,52.41,47.71,93355060.0,426873493.0
2025-02-25,002515.SZ,51.96,50.51,51.96,50.17,54431026.0,246916111.0
2025-02-26,002515.SZ,50.62,52.52,52.86,50.17,50584995.0,232883144.0
2025-02-27,002515.SZ,52.41,53.64,53.98,51.96,47142936.0,224200231.0
2025-02-28,002515.SZ,53.2,52.52,53.53,52.41,29058781.0,137329596.0
"#;
CSV_DATA
}
/// Parses the fixture CSV, converts `dt` from `%Y-%m-%d` strings to millisecond datetimes,
/// and formats the frame into daily raw bars.
fn get_bars() -> Vec<RawBar> {
    let mut frame = CsvReader::new(Cursor::new(example_data().as_bytes()))
        .finish()
        .unwrap();
    let ambiguous = StringChunked::from_iter(std::iter::once("raise"));
    let parsed_dt = frame
        .column("dt")
        .unwrap()
        .str()
        .unwrap()
        .as_datetime(
            Some("%Y-%m-%d"),
            polars::prelude::TimeUnit::Milliseconds,
            false,
            false,
            None,
            &ambiguous,
        )
        .unwrap();
    frame.with_column(parsed_dt).unwrap();
    format_standard_kline(frame, Freq::D).unwrap()
}
/// Parses `YYYY-MM-DD HH:MM:SS` as a UTC timestamp (panics on malformed input).
fn parse_dt(s: &str) -> DateTime<Utc> {
    let naive = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").unwrap();
    naive.and_local_timezone(Utc).unwrap()
}
#[test]
fn test_czsc_bi_list() {
    let c = CZSC::new(get_bars(), 50);
    // (start dt, end dt, direction, high, low) of each expected stroke, oldest first.
    let expected = [
        ("2025-01-13 00:00:00", "2025-01-17 00:00:00", Direction::Up, 51.74, 45.47),
        ("2025-01-17 00:00:00", "2025-02-06 00:00:00", Direction::Down, 51.74, 48.16),
        ("2025-02-06 00:00:00", "2025-02-11 00:00:00", Direction::Up, 50.73, 48.16),
        ("2025-02-11 00:00:00", "2025-02-19 00:00:00", Direction::Down, 50.73, 47.37),
    ];
    assert_eq!(c.bi_list.len(), expected.len());
    for (i, (bi, (sdt, edt, direction, high, low))) in
        c.bi_list.iter().zip(expected.iter()).enumerate()
    {
        assert_eq!(bi.start_dt(), parse_dt(sdt), "Index {i} sdt mismatch");
        assert_eq!(bi.end_dt(), parse_dt(edt), "Index {i} edt mismatch");
        assert_eq!(bi.direction, *direction, "Index {i} direction mismatch");
        assert!(
            (bi.get_high() - *high).abs() < 1e-4,
            "Index {i} high mismatch"
        );
        assert!(
            (bi.get_low() - *low).abs() < 1e-4,
            "Index {i} low mismatch"
        );
    }
}
#[test]
fn test_czsc_fx_list() {
    let bars = get_bars();
    let c = CZSC::new(bars, 50);
    // (dt, fractal value) of every expected fractal, in chronological order.
    let expected = [
        ("2025-01-15 00:00:00", 50.73),
        ("2025-01-16 00:00:00", 48.94),
        ("2025-01-17 00:00:00", 51.74),
        ("2025-01-24 00:00:00", 48.27),
        ("2025-01-27 00:00:00", 51.74),
        ("2025-02-06 00:00:00", 48.16),
        ("2025-02-11 00:00:00", 50.73),
        ("2025-02-12 00:00:00", 49.5),
        ("2025-02-13 00:00:00", 50.51),
        ("2025-02-19 00:00:00", 47.37),
        ("2025-02-20 00:00:00", 48.83),
        ("2025-02-21 00:00:00", 47.48),
    ];
    let fxs = c.get_fx_list();
    // `zip` stops at the shorter side, so without this check a missing or extra fractal
    // would go unnoticed.
    assert_eq!(fxs.len(), expected.len(), "fx count mismatch");
    for (i, (fx, exp)) in fxs.iter().zip(expected.iter()).enumerate() {
        assert_eq!(fx.dt, parse_dt(exp.0), "Index {i} dt mismatch");
        assert!((fx.fx - exp.1).abs() < 1e-4, "Index {i} fx mismatch");
    }
}
}