use std::{
collections::{BTreeMap, HashMap},
sync::Arc,
};
use nautilus_core::{UnixNanos, python::to_pyvalue_err};
use nautilus_model::{
identifiers::PositionId,
position::Position,
types::{Currency, Money},
};
use pyo3::prelude::*;
use crate::{
Returns,
analyzer::{PortfolioAnalyzer, Statistic},
statistics::{
alpha::Alpha, beta_ratio::BetaRatio, expectancy::Expectancy,
information_ratio::InformationRatio, long_ratio::LongRatio, loser_avg::AvgLoser,
loser_max::MaxLoser, loser_min::MinLoser, profit_factor::ProfitFactor,
returns_avg::ReturnsAverage, returns_avg_loss::ReturnsAverageLoss,
returns_avg_win::ReturnsAverageWin, returns_volatility::ReturnsVolatility,
risk_return_ratio::RiskReturnRatio, sharpe_ratio::SharpeRatio, sortino_ratio::SortinoRatio,
tracking_error::TrackingError, treynor_ratio::TreynorRatio, win_rate::WinRate,
winner_avg::AvgWinner, winner_max::MaxWinner, winner_min::MinWinner,
},
};
#[pymethods]
#[pyo3_stub_gen::derive::gen_stub_pymethods]
impl PortfolioAnalyzer {
/// Creates a new analyzer by delegating to `Self::new()`.
///
/// Exposed to Python as the class constructor through the `#[new]` attribute.
#[new]
#[must_use]
pub fn py_new() -> Self {
Self::new()
}
/// Returns a short textual representation that reports how many currencies the
/// analyzer currently exposes through `currencies()`.
fn __repr__(&self) -> String {
    let currencies = self.currencies();
    format!("PortfolioAnalyzer(currencies={})", currencies.len())
}
/// Returns owned copies of the currencies reported by the analyzer.
#[pyo3(name = "currencies")]
fn py_currencies(&self) -> Vec<Currency> {
    let mut owned = Vec::new();
    for currency in self.currencies() {
        // Each item is a borrowed currency; copy it out so the result owns its data.
        owned.push(*currency);
    }
    owned
}
/// Returns the analyzer's returns performance statistics as a name-to-value map.
#[pyo3(name = "get_performance_stats_returns")]
fn py_get_performance_stats_returns(&self) -> HashMap<String, f64> {
    let mut stats: HashMap<String, f64> = HashMap::new();
    stats.extend(self.get_performance_stats_returns());
    stats
}
/// Returns the analyzer's position-returns performance statistics as a name-to-value map.
#[pyo3(name = "get_performance_stats_position_returns")]
fn py_get_performance_stats_position_returns(&self) -> HashMap<String, f64> {
    let mut stats: HashMap<String, f64> = HashMap::new();
    stats.extend(self.get_performance_stats_position_returns());
    stats
}
/// Returns the analyzer's portfolio-returns performance statistics as a name-to-value map.
#[pyo3(name = "get_performance_stats_portfolio_returns")]
fn py_get_performance_stats_portfolio_returns(&self) -> HashMap<String, f64> {
    let mut stats: HashMap<String, f64> = HashMap::new();
    stats.extend(self.get_performance_stats_portfolio_returns());
    stats
}
/// Returns the statistics computed against a benchmark series as a name-to-value map.
///
/// `benchmark` maps raw `u64` timestamps to values; each key is converted with
/// `UnixNanos::from` before the series is handed to the analyzer.
#[pyo3(name = "get_performance_stats_returns_vs_benchmark")]
fn py_get_performance_stats_returns_vs_benchmark(
    &self,
    benchmark: BTreeMap<u64, f64>,
) -> HashMap<String, f64> {
    let series: Returns = benchmark
        .into_iter()
        .map(|(ts, ret)| (UnixNanos::from(ts), ret))
        .collect();

    let mut stats: HashMap<String, f64> = HashMap::new();
    stats.extend(self.get_performance_stats_returns_vs_benchmark(&series));
    stats
}
#[pyo3(name = "get_performance_stats_pnls")]
fn py_get_performance_stats_pnls(
&self,
currency: Option<&Currency>,
unrealized_pnl: Option<&Money>,
) -> PyResult<HashMap<String, f64>> {
self.get_performance_stats_pnls(currency, unrealized_pnl)
.map(|m| m.into_iter().collect())
.map_err(to_pyvalue_err)
}
/// Returns the analyzer's general performance statistics as a name-to-value map.
#[pyo3(name = "get_performance_stats_general")]
fn py_get_performance_stats_general(&self) -> HashMap<String, f64> {
    let mut stats: HashMap<String, f64> = HashMap::new();
    stats.extend(self.get_performance_stats_general());
    stats
}
/// Adds a position return for the given timestamp.
///
/// The raw `u64` timestamp is converted with `UnixNanos::from` before being forwarded.
#[pyo3(name = "add_position_return")]
fn py_add_position_return(&mut self, timestamp: u64, value: f64) {
    let ts = UnixNanos::from(timestamp);
    self.add_position_return(ts, value);
}
/// Adds a return for the given timestamp.
///
/// The raw `u64` timestamp is converted with `UnixNanos::from` before being forwarded.
#[pyo3(name = "add_return")]
fn py_add_return(&mut self, timestamp: u64, value: f64) {
    let ts = UnixNanos::from(timestamp);
    self.add_return(ts, value);
}
#[pyo3(name = "reset")]
fn py_reset(&mut self) {
self.reset();
}
/// Registers a statistic supplied from Python.
///
/// The Python object's `__class__.__name__` selects which Rust statistic type the object
/// is extracted as; the extracted value is then wrapped in an `Arc` and registered.
///
/// # Errors
///
/// Returns an error if the class name cannot be read, if extraction into the selected
/// type fails, or (built with `to_pyvalue_err`) if the class name matches none of the
/// statistic types listed below.
#[pyo3(name = "register_statistic")]
#[expect(clippy::needless_pass_by_value)]
fn py_register_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
    // Extracts `statistic` as the named Rust type and registers it on the analyzer.
    // A macro is used because every arm differs only in the concrete type.
    macro_rules! register_as {
        ($analyzer:expr, $stat_type:ty) => {{
            let extracted = statistic.extract::<$stat_type>(py)?;
            $analyzer.register_statistic(Arc::new(extracted));
        }};
    }

    let class_name = statistic
        .getattr(py, "__class__")?
        .getattr(py, "__name__")?
        .extract::<String>(py)?;

    match class_name.as_str() {
        "MaxWinner" => register_as!(self, MaxWinner),
        "MinWinner" => register_as!(self, MinWinner),
        "AvgWinner" => register_as!(self, AvgWinner),
        "MaxLoser" => register_as!(self, MaxLoser),
        "MinLoser" => register_as!(self, MinLoser),
        "AvgLoser" => register_as!(self, AvgLoser),
        "Expectancy" => register_as!(self, Expectancy),
        "WinRate" => register_as!(self, WinRate),
        "ReturnsVolatility" => register_as!(self, ReturnsVolatility),
        "ReturnsAverage" => register_as!(self, ReturnsAverage),
        "ReturnsAverageLoss" => register_as!(self, ReturnsAverageLoss),
        "ReturnsAverageWin" => register_as!(self, ReturnsAverageWin),
        "SharpeRatio" => register_as!(self, SharpeRatio),
        "SortinoRatio" => register_as!(self, SortinoRatio),
        "ProfitFactor" => register_as!(self, ProfitFactor),
        "RiskReturnRatio" => register_as!(self, RiskReturnRatio),
        "LongRatio" => register_as!(self, LongRatio),
        "Alpha" => register_as!(self, Alpha),
        "BetaRatio" => register_as!(self, BetaRatio),
        "InformationRatio" => register_as!(self, InformationRatio),
        "TrackingError" => register_as!(self, TrackingError),
        "TreynorRatio" => register_as!(self, TreynorRatio),
        _ => {
            return Err(to_pyvalue_err(format!(
                "Unknown statistic type: {class_name}"
            )));
        }
    }

    Ok(())
}
#[pyo3(name = "deregister_statistic")]
#[expect(clippy::needless_pass_by_value)]
fn py_deregister_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
let type_name = statistic
.getattr(py, "__class__")?
.getattr(py, "__name__")?
.extract::<String>(py)?;
match type_name.as_str() {
"MaxWinner" => {
let stat = statistic.extract::<MaxWinner>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"MinWinner" => {
let stat = statistic.extract::<MinWinner>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"AvgWinner" => {
let stat = statistic.extract::<AvgWinner>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"MaxLoser" => {
let stat = statistic.extract::<MaxLoser>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"MinLoser" => {
let stat = statistic.extract::<MinLoser>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"AvgLoser" => {
let stat = statistic.extract::<AvgLoser>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"Expectancy" => {
let stat = statistic.extract::<Expectancy>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"WinRate" => {
let stat = statistic.extract::<WinRate>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"ReturnsVolatility" => {
let stat = statistic.extract::<ReturnsVolatility>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"ReturnsAverage" => {
let stat = statistic.extract::<ReturnsAverage>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"ReturnsAverageLoss" => {
let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"ReturnsAverageWin" => {
let stat = statistic.extract::<ReturnsAverageWin>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"SharpeRatio" => {
let stat = statistic.extract::<SharpeRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"SortinoRatio" => {
let stat = statistic.extract::<SortinoRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"ProfitFactor" => {
let stat = statistic.extract::<ProfitFactor>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"RiskReturnRatio" => {
let stat = statistic.extract::<RiskReturnRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"LongRatio" => {
let stat = statistic.extract::<LongRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"Alpha" => {
let stat = statistic.extract::<Alpha>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"BetaRatio" => {
let stat = statistic.extract::<BetaRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"InformationRatio" => {
let stat = statistic.extract::<InformationRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"TrackingError" => {
let stat = statistic.extract::<TrackingError>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
"TreynorRatio" => {
let stat = statistic.extract::<TreynorRatio>(py)?;
self.deregister_statistic(&(Arc::new(stat) as Statistic));
}
_ => {
return Err(to_pyvalue_err(format!(
"Unknown statistic type: {type_name}"
)));
}
}
Ok(())
}
#[pyo3(name = "deregister_statistics")]
fn py_deregister_statistics(&mut self) {
self.deregister_statistics();
}
/// Adds positions supplied as Python objects.
///
/// For each object, the `_mem` attribute is read and extracted as a `Position`; the
/// collected positions are then passed to the analyzer in one call.
///
/// # Errors
///
/// Returns the first error raised while reading `_mem` or extracting a `Position`;
/// in that case nothing is added.
#[pyo3(name = "add_positions")]
#[expect(clippy::needless_pass_by_value)]
fn py_add_positions(&mut self, py: Python, positions: Vec<Py<PyAny>>) -> PyResult<()> {
    let mut extracted: Vec<Position> = Vec::with_capacity(positions.len());
    for obj in &positions {
        let mem = obj.getattr(py, "_mem")?;
        extracted.push(mem.extract::<Position>(py)?);
    }

    self.add_positions(&extracted);
    Ok(())
}
/// Forwards a trade's position ID, event timestamp and realized PnL to `add_trade`.
///
/// The raw `u64` `ts_event` is converted with `UnixNanos::from` first.
#[pyo3(name = "add_trade")]
#[allow(
    clippy::trivially_copy_pass_by_ref,
    reason = "matches underlying add_trade signature"
)]
fn py_add_trade(&mut self, position_id: &PositionId, ts_event: u64, realized_pnl: &Money) {
    let ts = UnixNanos::from(ts_event);
    self.add_trade(position_id, ts, realized_pnl);
}
/// Forwards a trade's position ID, event timestamp and realized PnL to `record_trade`.
///
/// The raw `u64` `ts_event` is converted with `UnixNanos::from` first.
#[pyo3(name = "record_trade")]
#[allow(
    clippy::trivially_copy_pass_by_ref,
    reason = "matches underlying record_trade signature"
)]
fn py_record_trade(&mut self, position_id: &PositionId, ts_event: u64, realized_pnl: &Money) {
    let ts = UnixNanos::from(ts_event);
    self.record_trade(position_id, ts, realized_pnl);
}
/// Looks up a statistic by `name` and returns the value of its `name()` method,
/// or `None` when the analyzer's lookup finds nothing.
#[pyo3(name = "statistic")]
fn py_statistic(&self, name: &str) -> Option<String> {
    let found = self.statistic(name)?;
    Some(found.name())
}
/// Builds a Python dict from the analyzer's returns, keyed by the `u64` timestamp.
///
/// # Errors
///
/// Propagates any error raised while inserting an entry into the dict.
#[pyo3(name = "returns")]
fn py_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
    use pyo3::types::PyDict;

    let out = PyDict::new(py);
    for (ts, ret) in self.returns() {
        out.set_item(ts.as_u64(), ret)?;
    }

    let obj: Py<PyAny> = out.into();
    Ok(obj)
}
/// Builds a Python dict from the analyzer's position returns, keyed by the `u64` timestamp.
///
/// # Errors
///
/// Propagates any error raised while inserting an entry into the dict.
#[pyo3(name = "position_returns")]
fn py_position_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
    use pyo3::types::PyDict;

    let out = PyDict::new(py);
    for (ts, ret) in self.position_returns() {
        out.set_item(ts.as_u64(), ret)?;
    }

    let obj: Py<PyAny> = out.into();
    Ok(obj)
}
/// Builds a Python dict from the analyzer's portfolio returns, keyed by the `u64` timestamp.
///
/// # Errors
///
/// Propagates any error raised while inserting an entry into the dict.
#[pyo3(name = "portfolio_returns")]
fn py_portfolio_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
    use pyo3::types::PyDict;

    let out = PyDict::new(py);
    for (ts, ret) in self.portfolio_returns() {
        out.set_item(ts.as_u64(), ret)?;
    }

    let obj: Py<PyAny> = out.into();
    Ok(obj)
}
/// Returns the realized PnLs for `currency` as a Python list, or Python `None` when the
/// analyzer returns nothing.
///
/// Each list entry is a tuple of the position ID rendered as a string, the event
/// timestamp as `u64`, and the PnL value.
///
/// # Errors
///
/// Propagates any error raised while appending an entry to the list.
#[pyo3(name = "realized_pnls")]
fn py_realized_pnls(&self, py: Python, currency: Option<&Currency>) -> PyResult<Py<PyAny>> {
    use pyo3::types::PyList;

    let Some(entries) = self.realized_pnls(currency) else {
        return Ok(py.None());
    };

    let out = PyList::empty(py);
    for (position_id, ts_event, pnl) in entries {
        out.append((position_id.to_string(), ts_event.as_u64(), pnl))?;
    }
    Ok(out.into())
}
/// Returns the analyzer's total PnL for the given currency and unrealized PnL.
///
/// # Errors
///
/// Any error returned by the analyzer is converted with `to_pyvalue_err`.
#[pyo3(name = "total_pnl")]
fn py_total_pnl(
    &self,
    currency: Option<&Currency>,
    unrealized_pnl: Option<&Money>,
) -> PyResult<f64> {
    let total = self
        .total_pnl(currency, unrealized_pnl)
        .map_err(to_pyvalue_err)?;
    Ok(total)
}
/// Returns the analyzer's total PnL percentage for the given currency and unrealized PnL.
///
/// # Errors
///
/// Any error returned by the analyzer is converted with `to_pyvalue_err`.
#[pyo3(name = "total_pnl_percentage")]
fn py_total_pnl_percentage(
    &self,
    currency: Option<&Currency>,
    unrealized_pnl: Option<&Money>,
) -> PyResult<f64> {
    let percentage = self
        .total_pnl_percentage(currency, unrealized_pnl)
        .map_err(to_pyvalue_err)?;
    Ok(percentage)
}
#[pyo3(name = "get_stats_pnls_formatted")]
fn py_get_stats_pnls_formatted(
&self,
currency: Option<&Currency>,
unrealized_pnl: Option<&Money>,
) -> PyResult<Vec<String>> {
self.get_stats_pnls_formatted(currency, unrealized_pnl)
.map_err(to_pyvalue_err)
}
#[pyo3(name = "get_stats_returns_formatted")]
fn py_get_stats_returns_formatted(&self) -> Vec<String> {
self.get_stats_returns_formatted()
}
#[pyo3(name = "get_stats_position_returns_formatted")]
fn py_get_stats_position_returns_formatted(&self) -> Vec<String> {
self.get_stats_position_returns_formatted()
}
#[pyo3(name = "get_stats_portfolio_returns_formatted")]
fn py_get_stats_portfolio_returns_formatted(&self) -> Vec<String> {
self.get_stats_portfolio_returns_formatted()
}
#[pyo3(name = "get_stats_general_formatted")]
fn py_get_stats_general_formatted(&self) -> Vec<String> {
self.get_stats_general_formatted()
}
}