1use std::default::Default;
2
3use ::popgetter_core::{
4 config::Config,
5 data_request_spec::DataRequestSpec,
6 search::{
7 CaseSensitivity, DownloadParams, MatchType, MetricId, Params, SearchConfig, SearchParams,
8 SearchText,
9 },
10 Popgetter, COL,
11};
12use polars::prelude::DataFrame;
13use pyo3::{
14 exceptions::PyValueError,
15 prelude::*,
16 types::{PyDict, PyString},
17};
18use pyo3_polars::PyDataFrame;
19use serde::de::DeserializeOwned;
20
21fn convert_py_dict<T: DeserializeOwned>(obj: &Bound<'_, PyAny>) -> PyResult<T> {
23 if let Ok(dict) = obj.downcast::<PyDict>() {
25 let json = PyModule::import_bound(dict.py(), "json")?.getattr("dumps")?;
26 let json_str: String = json.call1((dict,))?.extract()?;
27
28 serde_json::from_str::<T>(&json_str)
30 .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
32 } else {
33 Err(PyErr::new::<PyValueError, _>("Argument must be a 'dict'"))
34 }
35}
36
37async fn _search(search_params: SearchParams) -> anyhow::Result<DataFrame> {
39 let search_results = Popgetter::new_with_config_and_cache(Config::default())
40 .await?
41 .search(&search_params);
42 Ok(search_results.0.select([
43 COL::METRIC_ID,
44 COL::METRIC_HUMAN_READABLE_NAME,
45 COL::METRIC_DESCRIPTION,
46 COL::METRIC_HXL_TAG,
47 COL::SOURCE_DATA_RELEASE_COLLECTION_PERIOD_START,
48 COL::COUNTRY_NAME_SHORT_EN,
49 COL::GEOMETRY_LEVEL,
50 ])?)
51}
52
53async fn _search_and_download(search_params: SearchParams) -> anyhow::Result<DataFrame> {
55 Popgetter::new_with_config_and_cache(Config::default())
56 .await?
57 .download_params(&Params {
58 search: search_params.clone(),
59 download: DownloadParams {
61 include_geoms: true,
62 region_spec: search_params.region_spec,
63 },
64 })
65 .await
66}
67
68fn get_search_params(obj: &Bound<'_, PyAny>) -> PyResult<SearchParams> {
72 if let Ok(text) = obj.downcast::<PyString>() {
73 println!(
74 "Argument is 'str', searching as text or comma-separated metric IDs: {}",
75 text
76 );
77 let search_text = SearchText {
78 text: text.to_string(),
79 ..SearchText::default()
80 };
81 return Ok(SearchParams {
82 text: vec![search_text],
83 metric_id: text
84 .to_string()
85 .split(',')
86 .map(|id_str| MetricId {
87 id: id_str.to_string(),
88 config: SearchConfig {
89 match_type: MatchType::Startswith,
90 case_sensitivity: CaseSensitivity::Insensitive,
91 },
92 })
93 .collect::<Vec<_>>(),
94 ..Default::default()
95 });
96 }
97 if let Ok(dict) = obj.downcast::<PyDict>() {
98 println!(
99 "Argument is 'dict', searching as search parameters: {}",
100 dict
101 );
102 return convert_py_dict(dict);
103 };
104 Err(PyErr::new::<PyValueError, _>(
105 "Argument must be either 'str' (text) or 'dict' (search parameters)",
106 ))
107}
108
109async fn _download_data_request_spec(data_request: DataRequestSpec) -> anyhow::Result<DataFrame> {
111 Popgetter::new_with_config_and_cache(Config::default())
112 .await?
113 .download_data_request_spec(&data_request)
114 .await
115}
116
117fn get_data_request_spec(obj: &Bound<'_, PyAny>) -> PyResult<DataRequestSpec> {
119 if let Ok(dict) = obj.downcast::<PyDict>() {
120 return convert_py_dict(dict);
121 }
122 Err(PyErr::new::<PyValueError, _>(
123 "Argument must be 'dict' (data request spec)",
124 ))
125}
126
127#[pyfunction]
130fn download_data_request(
131 #[pyo3(from_py_with = "get_data_request_spec")] data_request: DataRequestSpec,
132) -> PyResult<PyDataFrame> {
133 let rt = tokio::runtime::Builder::new_current_thread()
134 .enable_all()
135 .build()?;
136 let result = rt.block_on(_download_data_request_spec(data_request));
137 Ok(PyDataFrame(result?))
138}
139
140#[pyfunction]
143fn search(
144 #[pyo3(from_py_with = "get_search_params")] search_params: SearchParams,
145) -> PyResult<PyDataFrame> {
146 let rt = tokio::runtime::Builder::new_current_thread()
147 .enable_all()
148 .build()?;
149 let result = rt.block_on(_search(search_params));
150 Ok(PyDataFrame(result?))
151}
152
153#[pyfunction]
156fn download(
157 #[pyo3(from_py_with = "get_search_params")] search_params: SearchParams,
158) -> PyResult<PyDataFrame> {
159 let rt = tokio::runtime::Builder::new_current_thread()
160 .enable_all()
161 .build()?;
162 let result = rt.block_on(_search_and_download(search_params));
163 Ok(PyDataFrame(result?))
164}
165
166#[pymodule]
168fn popgetter(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
169 m.add_function(wrap_pyfunction!(search, m)?)?;
170 m.add_function(wrap_pyfunction!(download, m)?)?;
171 m.add_function(wrap_pyfunction!(download_data_request, m)?)?;
172 Ok(())
173}