use std::fmt::Write;
use crate::{
ClientConstants, ClientMetadata, CohortConstants, IndexSetPattern, PythonSyntax,
StructuralPattern, format_json, generate_parameterized_field, index_to_field_name,
};
/// Appends the class-level constants of the generated Python client to
/// `output`: the version string, the INDEXES JSON blob, the pool-id map,
/// and every cohort constant.
pub fn generate_class_constants(output: &mut String) {
    let constants = ClientConstants::collect();
    // Version first, then each constant rendered as an indented class attribute.
    writeln!(output, " VERSION = \"{}\"\n", constants.version).unwrap();
    write_class_const(output, "INDEXES", &format_json(&constants.indexes));
    // A BTreeMap keyed by the stringified pool id keeps the emitted JSON
    // deterministically ordered across regenerations.
    let mut sorted_pools: std::collections::BTreeMap<String, &str> =
        std::collections::BTreeMap::new();
    for (pool_id, pool_name) in constants.pool_map.iter() {
        sorted_pools.insert(pool_id.to_string(), *pool_name);
    }
    write_class_const(output, "POOL_ID_TO_POOL_NAME", &format_json(&sorted_pools));
    // Cohort constants are emitted verbatim under their declared names.
    for (const_name, const_value) in CohortConstants::all() {
        write_class_const(output, const_name, &format_json(&const_value));
    }
}
/// Writes `NAME = <json>` as an indented class attribute, re-indenting every
/// continuation line of a multi-line JSON value, and ends with a blank line.
fn write_class_const(output: &mut String, name: &str, json: &str) {
    // Build the indented body incrementally instead of collecting a Vec
    // and joining: first line carries the `NAME = ` prefix, the rest only
    // the indent.
    let mut body = String::new();
    for (line_no, line) in json.lines().enumerate() {
        if line_no > 0 {
            body.push('\n');
        }
        if line_no == 0 {
            write!(body, " {} = {}", name, line).unwrap();
        } else {
            body.push(' ');
            body.push_str(line);
        }
    }
    // Trailing `\n` in the format string yields the blank separator line.
    writeln!(output, "{}\n", body).unwrap();
}
/// Appends the Python base-client preamble to `output`: the `BrkError`
/// exception, `BrkClientBase` (a thin `http.client` GET wrapper that lazily
/// creates and reuses one HTTP(S) connection, and drops it on connection
/// errors so the next call reconnects), and the `_m`/`_p` helpers that build
/// series names from suffix/prefix parts.
///
/// The whole body is a single raw-string `writeln!` template; doubled braces
/// (`{{` / `}}`) render as single braces so the emitted Python f-strings
/// survive Rust formatting. Do not edit whitespace inside the template —
/// it is significant Python indentation in the generated file.
///
/// NOTE(review): the emitted code references `Optional`, `Union`, `Any`,
/// `urlparse`, `HTTPConnection`, `HTTPSConnection` and `json`, and
/// `__enter__` annotates its own class by name — assumes the generated
/// file's import header (emitted elsewhere, e.g. with
/// `from __future__ import annotations` or equivalent) provides these;
/// confirm against the preamble generator.
pub fn generate_base_client(output: &mut String) {
writeln!(
output,
r#"class BrkError(Exception):
"""Custom error class for BRK client errors."""
def __init__(self, message: str, status: Optional[int] = None):
super().__init__(message)
self.status = status
class BrkClientBase:
"""Base HTTP client for making requests."""
def __init__(self, base_url: str, timeout: float = 30.0):
parsed = urlparse(base_url)
self._host = parsed.netloc
self._secure = parsed.scheme == 'https'
self._timeout = timeout
self._conn: Optional[Union[HTTPSConnection, HTTPConnection]] = None
def _connect(self) -> Union[HTTPSConnection, HTTPConnection]:
"""Get or create HTTP connection."""
if self._conn is None:
if self._secure:
self._conn = HTTPSConnection(self._host, timeout=self._timeout)
else:
self._conn = HTTPConnection(self._host, timeout=self._timeout)
return self._conn
def get(self, path: str) -> bytes:
"""Make a GET request and return raw bytes."""
try:
conn = self._connect()
conn.request("GET", path)
res = conn.getresponse()
data = res.read()
if res.status >= 400:
raise BrkError(f"HTTP error: {{res.status}}", res.status)
return data
except (ConnectionError, OSError, TimeoutError) as e:
self._conn = None
raise BrkError(str(e))
def get_json(self, path: str) -> Any:
"""Make a GET request and return JSON."""
return json.loads(self.get(path))
def get_text(self, path: str) -> str:
"""Make a GET request and return text."""
return self.get(path).decode()
def close(self) -> None:
"""Close the HTTP client."""
if self._conn:
self._conn.close()
self._conn = None
def __enter__(self) -> BrkClientBase:
return self
def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
self.close()
def _m(acc: str, s: str) -> str:
"""Build series name with suffix."""
if not s: return acc
return f"{{acc}}_{{s}}" if acc else s
def _p(prefix: str, acc: str) -> str:
"""Build series name with prefix."""
return f"{{prefix}}_{{acc}}" if acc else prefix
"#
)
.unwrap();
}
/// Appends the static (non-pattern-specific) endpoint machinery of the
/// Python client to `output`, all from one raw-string template:
///
/// - `_index_to_date` / `_date_to_index`: bidirectional conversion between
///   integer index values and dates, anchored at the 2009-01-01 UTC epoch
///   (`_EPOCH`) for sub-daily indexes and at `_GENESIS` (2009-01-03) /
///   `_DAY_ONE` (2009-01-09) for the `day1` index, which has a special
///   6-day gap between index 0 and index 1.
/// - `SeriesData[T]` / `DateSeriesData[T]`: dataclasses exposing the fetched
///   range as dicts/items/iterators plus optional polars/pandas conversion
///   (both imported lazily so neither is a hard dependency).
/// - `_EndpointConfig` plus the builder classes (`RangeBuilder`,
///   `SingleItemBuilder`, `SkippedBuilder` and their Date* variants) and the
///   user-facing `SeriesEndpoint` / `DateSeriesEndpoint` with slicing,
///   `head`/`tail`/`skip` and CSV fetch.
/// - The `SeriesPattern` protocol used by generated pattern classes.
///
/// Doubled braces (`{{` / `}}`) in the template render as single braces in
/// the emitted Python (f-strings, dict literals). Do not edit whitespace
/// inside the template — it is significant Python indentation.
///
/// NOTE(review): the template references `Index`, `T`, `Generic`,
/// `Protocol`, `overload`, `dataclass`, `date`/`datetime`/`timezone`/
/// `timedelta` and typing names — assumes the generated file's import
/// header provides them; the `pl.DataFrame` / `pd.DataFrame` return
/// annotations additionally appear to rely on deferred annotation
/// evaluation since the imports are function-local — confirm the preamble
/// emits `from __future__ import annotations` or a TYPE_CHECKING import.
pub fn generate_endpoint_class(output: &mut String) {
writeln!(
output,
r#"# Date conversion constants
_GENESIS = date(2009, 1, 3) # day1 0, week1 0
_DAY_ONE = date(2009, 1, 9) # day1 1 (6 day gap after genesis)
_EPOCH = datetime(2009, 1, 1, tzinfo=timezone.utc)
_DATE_INDEXES = frozenset([
'minute10', 'minute30',
'hour1', 'hour4', 'hour12',
'day1', 'day3', 'week1',
'month1', 'month3', 'month6',
'year1', 'year10',
])
def _index_to_date(index: str, i: int) -> Union[date, datetime]:
"""Convert an index value to a date/datetime for date-based indexes."""
if index == 'minute10':
return _EPOCH + timedelta(minutes=i * 10)
elif index == 'minute30':
return _EPOCH + timedelta(minutes=i * 30)
elif index == 'hour1':
return _EPOCH + timedelta(hours=i)
elif index == 'hour4':
return _EPOCH + timedelta(hours=i * 4)
elif index == 'hour12':
return _EPOCH + timedelta(hours=i * 12)
elif index == 'day1':
return _GENESIS if i == 0 else _DAY_ONE + timedelta(days=i - 1)
elif index == 'day3':
return _EPOCH.date() - timedelta(days=1) + timedelta(days=i * 3)
elif index == 'week1':
return _GENESIS + timedelta(weeks=i)
elif index == 'month1':
return date(2009 + i // 12, i % 12 + 1, 1)
elif index == 'month3':
m = i * 3
return date(2009 + m // 12, m % 12 + 1, 1)
elif index == 'month6':
m = i * 6
return date(2009 + m // 12, m % 12 + 1, 1)
elif index == 'year1':
return date(2009 + i, 1, 1)
elif index == 'year10':
return date(2009 + i * 10, 1, 1)
else:
raise ValueError(f"{{index}} is not a date-based index")
def _date_to_index(index: str, d: Union[date, datetime]) -> int:
"""Convert a date/datetime to an index value for date-based indexes.
Returns the floor index (latest index whose date is <= the given date).
For sub-day indexes (minute*, hour*), a plain date is treated as midnight UTC.
"""
if index in ('minute10', 'minute30', 'hour1', 'hour4', 'hour12'):
if isinstance(d, datetime):
dt = d if d.tzinfo else d.replace(tzinfo=timezone.utc)
else:
dt = datetime(d.year, d.month, d.day, tzinfo=timezone.utc)
secs = int((dt - _EPOCH).total_seconds())
div = {{'minute10': 600, 'minute30': 1800,
'hour1': 3600, 'hour4': 14400, 'hour12': 43200}}
return secs // div[index]
dd = d.date() if isinstance(d, datetime) else d
if index == 'day1':
if dd < _DAY_ONE:
return 0
return 1 + (dd - _DAY_ONE).days
elif index == 'day3':
return (dd - date(2008, 12, 31)).days // 3
elif index == 'week1':
return (dd - _GENESIS).days // 7
elif index == 'month1':
return (dd.year - 2009) * 12 + (dd.month - 1)
elif index == 'month3':
return (dd.year - 2009) * 4 + (dd.month - 1) // 3
elif index == 'month6':
return (dd.year - 2009) * 2 + (dd.month - 1) // 6
elif index == 'year1':
return dd.year - 2009
elif index == 'year10':
return (dd.year - 2009) // 10
else:
raise ValueError(f"{{index}} is not a date-based index")
@dataclass
class SeriesData(Generic[T]):
"""Series data with range information. Always int-indexed."""
version: int
index: Index
type: str
total: int
start: int
end: int
stamp: str
data: List[T]
@property
def is_date_based(self) -> bool:
"""Whether this series uses a date-based index."""
return self.index in _DATE_INDEXES
def indexes(self) -> List[int]:
"""Get raw index numbers."""
return list(range(self.start, self.end))
def keys(self) -> List[int]:
"""Get keys as index numbers."""
return self.indexes()
def items(self) -> List[Tuple[int, T]]:
"""Get (index, value) pairs."""
return list(zip(self.indexes(), self.data))
def to_dict(self) -> Dict[int, T]:
"""Return {{index: value}} dict."""
return dict(zip(self.indexes(), self.data))
def __iter__(self) -> Iterator[Tuple[int, T]]:
"""Iterate over (index, value) pairs."""
return iter(zip(self.indexes(), self.data))
def __len__(self) -> int:
return len(self.data)
def to_polars(self) -> pl.DataFrame:
"""Convert to Polars DataFrame with 'index' and 'value' columns."""
try:
import polars as pl # type: ignore[import-not-found]
except ImportError:
raise ImportError("polars is required: pip install polars")
return pl.DataFrame({{"index": self.indexes(), "value": self.data}})
def to_pandas(self) -> pd.DataFrame:
"""Convert to Pandas DataFrame with 'index' and 'value' columns."""
try:
import pandas as pd # type: ignore[import-not-found]
except ImportError:
raise ImportError("pandas is required: pip install pandas")
return pd.DataFrame({{"index": self.indexes(), "value": self.data}})
@dataclass
class DateSeriesData(SeriesData[T]):
"""Series data with date-based index. Extends SeriesData with date methods."""
def dates(self) -> List[Union[date, datetime]]:
"""Get dates for the index range. Returns datetime for sub-daily indexes, date for daily+."""
return [_index_to_date(self.index, i) for i in range(self.start, self.end)]
def date_items(self) -> List[Tuple[Union[date, datetime], T]]:
"""Get (date, value) pairs."""
return list(zip(self.dates(), self.data))
def to_date_dict(self) -> Dict[Union[date, datetime], T]:
"""Return {{date: value}} dict."""
return dict(zip(self.dates(), self.data))
def to_polars(self, with_dates: bool = True) -> pl.DataFrame:
"""Convert to Polars DataFrame.
Returns a DataFrame with columns:
- 'date' and 'value' if with_dates=True (default)
- 'index' and 'value' otherwise
"""
try:
import polars as pl # type: ignore[import-not-found]
except ImportError:
raise ImportError("polars is required: pip install polars")
if with_dates:
return pl.DataFrame({{"date": self.dates(), "value": self.data}})
return pl.DataFrame({{"index": self.indexes(), "value": self.data}})
def to_pandas(self, with_dates: bool = True) -> pd.DataFrame:
"""Convert to Pandas DataFrame.
Returns a DataFrame with columns:
- 'date' and 'value' if with_dates=True (default)
- 'index' and 'value' otherwise
"""
try:
import pandas as pd # type: ignore[import-not-found]
except ImportError:
raise ImportError("pandas is required: pip install pandas")
if with_dates:
return pd.DataFrame({{"date": self.dates(), "value": self.data}})
return pd.DataFrame({{"index": self.indexes(), "value": self.data}})
# Type aliases for non-generic usage
AnySeriesData = SeriesData[Any]
AnyDateSeriesData = DateSeriesData[Any]
class _EndpointConfig:
"""Shared endpoint configuration."""
client: BrkClientBase
name: str
index: Index
start: Optional[int]
end: Optional[int]
def __init__(self, client: BrkClientBase, name: str, index: Index,
start: Optional[int] = None, end: Optional[int] = None):
self.client = client
self.name = name
self.index = index
self.start = start
self.end = end
def path(self) -> str:
return f"/api/series/{{self.name}}/{{self.index}}"
def _build_path(self, format: Optional[str] = None) -> str:
params = []
if self.start is not None:
params.append(f"start={{self.start}}")
if self.end is not None:
params.append(f"end={{self.end}}")
if format is not None:
params.append(f"format={{format}}")
query = "&".join(params)
p = self.path()
return f"{{p}}?{{query}}" if query else p
def _new(self, start: Optional[int] = None, end: Optional[int] = None) -> _EndpointConfig:
return _EndpointConfig(self.client, self.name, self.index, start, end)
def get_series(self) -> SeriesData[Any]:
return SeriesData(**self.client.get_json(self._build_path()))
def get_date_series(self) -> DateSeriesData[Any]:
return DateSeriesData(**self.client.get_json(self._build_path()))
def get_csv(self) -> str:
return self.client.get_text(self._build_path(format='csv'))
class RangeBuilder(Generic[T]):
"""Builder with range specified."""
def __init__(self, config: _EndpointConfig):
self._config = config
def fetch(self) -> SeriesData[T]:
"""Fetch the range as parsed JSON."""
return self._config.get_series()
def fetch_csv(self) -> str:
"""Fetch the range as CSV string."""
return self._config.get_csv()
class SingleItemBuilder(Generic[T]):
"""Builder for single item access."""
def __init__(self, config: _EndpointConfig):
self._config = config
def fetch(self) -> SeriesData[T]:
"""Fetch the single item."""
return self._config.get_series()
def fetch_csv(self) -> str:
"""Fetch as CSV."""
return self._config.get_csv()
class SkippedBuilder(Generic[T]):
"""Builder after calling skip(n). Chain with take() to specify count."""
def __init__(self, config: _EndpointConfig):
self._config = config
def take(self, n: int) -> RangeBuilder[T]:
"""Take n items after the skipped position."""
start = self._config.start or 0
return RangeBuilder(self._config._new(start, start + n))
def fetch(self) -> SeriesData[T]:
"""Fetch from skipped position to end."""
return self._config.get_series()
def fetch_csv(self) -> str:
"""Fetch as CSV."""
return self._config.get_csv()
class DateRangeBuilder(RangeBuilder[T]):
"""Range builder that returns DateSeriesData."""
def fetch(self) -> DateSeriesData[T]:
return self._config.get_date_series()
class DateSingleItemBuilder(SingleItemBuilder[T]):
"""Single item builder that returns DateSeriesData."""
def fetch(self) -> DateSeriesData[T]:
return self._config.get_date_series()
class DateSkippedBuilder(SkippedBuilder[T]):
"""Skipped builder that returns DateSeriesData."""
def take(self, n: int) -> DateRangeBuilder[T]:
start = self._config.start or 0
return DateRangeBuilder(self._config._new(start, start + n))
def fetch(self) -> DateSeriesData[T]:
return self._config.get_date_series()
class SeriesEndpoint(Generic[T]):
"""Builder for series endpoint queries with int-based indexing.
Examples:
data = endpoint.fetch()
data = endpoint[5].fetch()
data = endpoint[:10].fetch()
data = endpoint.head(20).fetch()
data = endpoint.skip(100).take(10).fetch()
"""
def __init__(self, client: BrkClientBase, name: str, index: Index):
self._config = _EndpointConfig(client, name, index)
@overload
def __getitem__(self, key: int) -> SingleItemBuilder[T]: ...
@overload
def __getitem__(self, key: slice) -> RangeBuilder[T]: ...
def __getitem__(self, key: Union[int, slice]) -> Union[SingleItemBuilder[T], RangeBuilder[T]]:
"""Access single item or slice by integer index."""
if isinstance(key, int):
return SingleItemBuilder(self._config._new(key, key + 1))
return RangeBuilder(self._config._new(key.start, key.stop))
def head(self, n: int = 10) -> RangeBuilder[T]:
"""Get the first n items."""
return RangeBuilder(self._config._new(end=n))
def tail(self, n: int = 10) -> RangeBuilder[T]:
"""Get the last n items."""
return RangeBuilder(self._config._new(end=0) if n == 0 else self._config._new(start=-n))
def skip(self, n: int) -> SkippedBuilder[T]:
"""Skip the first n items."""
return SkippedBuilder(self._config._new(start=n))
def fetch(self) -> SeriesData[T]:
"""Fetch all data."""
return self._config.get_series()
def fetch_csv(self) -> str:
"""Fetch all data as CSV."""
return self._config.get_csv()
def path(self) -> str:
"""Get the base endpoint path."""
return self._config.path()
class DateSeriesEndpoint(Generic[T]):
"""Builder for series endpoint queries with date-based indexing.
Accepts dates in __getitem__ and returns DateSeriesData from fetch().
Examples:
data = endpoint.fetch()
data = endpoint[date(2020, 1, 1)].fetch()
data = endpoint[date(2020, 1, 1):date(2023, 1, 1)].fetch()
data = endpoint[:10].fetch()
"""
def __init__(self, client: BrkClientBase, name: str, index: Index):
self._config = _EndpointConfig(client, name, index)
@overload
def __getitem__(self, key: int) -> DateSingleItemBuilder[T]: ...
@overload
def __getitem__(self, key: datetime) -> DateSingleItemBuilder[T]: ...
@overload
def __getitem__(self, key: date) -> DateSingleItemBuilder[T]: ...
@overload
def __getitem__(self, key: slice) -> DateRangeBuilder[T]: ...
def __getitem__(self, key: Union[int, slice, date, datetime]) -> Union[DateSingleItemBuilder[T], DateRangeBuilder[T]]:
"""Access single item or slice. Accepts int, date, or datetime."""
if isinstance(key, (date, datetime)):
idx = _date_to_index(self._config.index, key)
return DateSingleItemBuilder(self._config._new(idx, idx + 1))
if isinstance(key, int):
return DateSingleItemBuilder(self._config._new(key, key + 1))
start, stop = key.start, key.stop
if isinstance(start, (date, datetime)):
start = _date_to_index(self._config.index, start)
if isinstance(stop, (date, datetime)):
stop = _date_to_index(self._config.index, stop)
return DateRangeBuilder(self._config._new(start, stop))
def head(self, n: int = 10) -> DateRangeBuilder[T]:
"""Get the first n items."""
return DateRangeBuilder(self._config._new(end=n))
def tail(self, n: int = 10) -> DateRangeBuilder[T]:
"""Get the last n items."""
return DateRangeBuilder(self._config._new(end=0) if n == 0 else self._config._new(start=-n))
def skip(self, n: int) -> DateSkippedBuilder[T]:
"""Skip the first n items."""
return DateSkippedBuilder(self._config._new(start=n))
def fetch(self) -> DateSeriesData[T]:
"""Fetch all data."""
return self._config.get_date_series()
def fetch_csv(self) -> str:
"""Fetch all data as CSV."""
return self._config.get_csv()
def path(self) -> str:
"""Get the base endpoint path."""
return self._config.path()
# Type aliases for non-generic usage
AnySeriesEndpoint = SeriesEndpoint[Any]
AnyDateSeriesEndpoint = DateSeriesEndpoint[Any]
class SeriesPattern(Protocol[T]):
"""Protocol for series patterns with different index sets."""
@property
def name(self) -> str:
"""Get the series name."""
...
def indexes(self) -> List[str]:
"""Get the list of available indexes for this series."""
...
def get(self, index: Index) -> Optional[SeriesEndpoint[T]]:
"""Get an endpoint builder for a specific index, if supported."""
...
"#
)
.unwrap();
}
/// Emits, for every index-set pattern: a static Python tuple of its index
/// names (`_i1`, `_i2`, ...), the `_ep`/`_dep` endpoint factory helpers,
/// a `_<Name>By` accessor class with one method per index, and the public
/// `<Name>` class implementing the `SeriesPattern` protocol. No-op when
/// `patterns` is empty.
pub fn generate_index_accessors(output: &mut String, patterns: &[IndexSetPattern]) {
    if patterns.is_empty() {
        return;
    }
    writeln!(output, "# Static index tuples").unwrap();
    for (pos, pattern) in patterns.iter().enumerate() {
        let quoted: Vec<String> = pattern
            .indexes
            .iter()
            .map(|ix| format!("'{}'", ix.name()))
            .collect();
        // A one-element Python tuple needs a trailing comma to stay a tuple.
        let trailing = if quoted.len() == 1 { "," } else { "" };
        writeln!(output, "_i{} = ({}{})", pos + 1, quoted.join(", "), trailing).unwrap();
    }
    writeln!(output).unwrap();
    // Factory helpers shared by all generated accessor methods.
    writeln!(
        output,
        r#"def _ep(c: BrkClientBase, n: str, i: Index) -> SeriesEndpoint[Any]:
return SeriesEndpoint(c, n, i)
def _dep(c: BrkClientBase, n: str, i: Index) -> DateSeriesEndpoint[Any]:
return DateSeriesEndpoint(c, n, i)
"#
    )
    .unwrap();
    writeln!(output, "# Index accessor classes\n").unwrap();
    for (pos, pattern) in patterns.iter().enumerate() {
        let by_class = format!("_{}By", pattern.name);
        let tuple_var = format!("_i{}", pos + 1);
        // The `_<Name>By` class exposes one method per supported index.
        writeln!(output, "class {}(Generic[T]):", by_class).unwrap();
        writeln!(
            output,
            " def __init__(self, c: BrkClientBase, n: str): self._c, self._n = c, n"
        )
        .unwrap();
        for ix in &pattern.indexes {
            // Date-based indexes get the date-aware endpoint builder.
            let (endpoint_cls, factory) = if ix.is_date_based() {
                ("DateSeriesEndpoint", "_dep")
            } else {
                ("SeriesEndpoint", "_ep")
            };
            writeln!(
                output,
                " def {}(self) -> {}[T]: return {}(self._c, self._n, '{}')",
                index_to_field_name(ix),
                endpoint_cls,
                factory,
                ix.name()
            )
            .unwrap();
        }
        writeln!(output).unwrap();
        // Public pattern class: wraps the By-accessor and implements the
        // SeriesPattern protocol (name / indexes / get).
        writeln!(output, "class {}(Generic[T]):", pattern.name).unwrap();
        writeln!(output, " by: {}[T]", by_class).unwrap();
        writeln!(
            output,
            " def __init__(self, c: BrkClientBase, n: str): self._n, self.by = n, {}(c, n)",
            by_class
        )
        .unwrap();
        writeln!(output, " @property").unwrap();
        writeln!(output, " def name(self) -> str: return self._n").unwrap();
        writeln!(
            output,
            " def indexes(self) -> List[str]: return list({})",
            tuple_var
        )
        .unwrap();
        writeln!(
            output,
            " def get(self, index: Index) -> Optional[SeriesEndpoint[T]]: return _ep(self.by._c, self._n, index) if index in {} else None",
            tuple_var
        )
        .unwrap();
        writeln!(output).unwrap();
    }
}
/// Emits one Python class per structural pattern. Non-parameterizable
/// patterns become empty `pass` bodies; parameterizable ones get an
/// `__init__` (with an extra `disc` parameter when the pattern is
/// templated) whose field assignments are produced by
/// `generate_parameterized_field`. No-op when `patterns` is empty.
pub fn generate_structural_patterns(
    output: &mut String,
    patterns: &[StructuralPattern],
    metadata: &ClientMetadata,
) {
    if patterns.is_empty() {
        return;
    }
    writeln!(output, "# Reusable structural pattern classes\n").unwrap();
    for pattern in patterns {
        // Generic patterns parameterize the value type T in the class header.
        let base = if pattern.is_generic { "(Generic[T])" } else { "" };
        writeln!(output, "class {}{}:", pattern.name, base).unwrap();
        writeln!(
            output,
            " \"\"\"Pattern struct for repeated tree structure.\"\"\""
        )
        .unwrap();
        if !metadata.is_parameterizable(&pattern.name) {
            // No constructor to emit: an empty class body needs `pass`.
            writeln!(output, " pass\n").unwrap();
            continue;
        }
        writeln!(output, " ").unwrap();
        // Templated patterns thread an extra discriminator through __init__.
        let init_params = if pattern.is_templated() {
            "client: BrkClientBase, acc: str, disc: str"
        } else {
            "client: BrkClientBase, acc: str"
        };
        writeln!(output, " def __init__(self, {}):", init_params).unwrap();
        writeln!(
            output,
            " \"\"\"Create pattern node with accumulated series name.\"\"\""
        )
        .unwrap();
        let syntax = PythonSyntax;
        for member in &pattern.fields {
            generate_parameterized_field(output, &syntax, member, pattern, metadata, " ");
        }
        writeln!(output).unwrap();
    }
}