1mod temporal;
2
3use std::fmt::Display;
4use std::str::FromStr;
5use std::sync::Arc;
6
7use arrow_array::cast::AsArray;
8use arrow_array::types::*;
9use arrow_array::{Array, ArrayRef, Datum, UnionArray};
10use arrow_cast::cast;
11use arrow_cast::pretty::pretty_format_columns_with_options;
12use arrow_schema::{ArrowError, DataType, Field, FieldRef, TimeUnit};
13use indexmap::IndexMap;
14use pyo3::prelude::*;
15use pyo3::types::{PyCapsule, PyList, PyTuple, PyType};
16use pyo3::{intern, IntoPyObjectExt};
17
18use crate::error::PyArrowResult;
19use crate::export::{Arro3DataType, Arro3Field, Arro3Scalar};
20use crate::ffi::to_array_pycapsules;
21use crate::scalar::temporal::{as_datetime_with_timezone, PyArrowTz};
22use crate::utils::default_repr_options;
23use crate::{PyArray, PyField};
24
25#[derive(Debug)]
27#[pyclass(module = "arro3.core._core", name = "Scalar", subclass, frozen)]
28pub struct PyScalar {
29 array: ArrayRef,
30 field: FieldRef,
31}
32
33impl PyScalar {
34 pub unsafe fn new_unchecked(array: ArrayRef, field: FieldRef) -> Self {
41 Self { array, field }
42 }
43
44 pub fn try_from_array_ref(array: ArrayRef) -> PyArrowResult<Self> {
46 let field = Field::new("", array.data_type().clone(), true);
47 Self::try_new(array, Arc::new(field))
48 }
49
50 pub fn try_new(array: ArrayRef, field: FieldRef) -> PyArrowResult<Self> {
55 let (array, field) = PyArray::try_new(array, field)?.into_inner();
57 if array.len() != 1 {
58 return Err(ArrowError::SchemaError(
59 "Expected array of length 1 for scalar".to_string(),
60 )
61 .into());
62 }
63
64 Ok(Self { array, field })
65 }
66
67 pub fn try_from_arrow_pycapsule(
69 schema_capsule: &Bound<PyCapsule>,
70 array_capsule: &Bound<PyCapsule>,
71 ) -> PyArrowResult<Self> {
72 let (array, field) =
73 PyArray::from_arrow_pycapsule(schema_capsule, array_capsule)?.into_inner();
74 Self::try_new(array, field)
75 }
76
77 pub fn array(&self) -> &ArrayRef {
79 &self.array
80 }
81
82 pub fn field(&self) -> &FieldRef {
84 &self.field
85 }
86
87 pub fn into_inner(self) -> (ArrayRef, FieldRef) {
89 (self.array, self.field)
90 }
91
92 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
96 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
97 arro3_mod.getattr(intern!(py, "Scalar"))?.call_method1(
98 intern!(py, "from_arrow_pycapsule"),
99 self.__arrow_c_array__(py, None)?,
100 )
101 }
102
103 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
107 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
108 let capsules = to_array_pycapsules(py, self.field.clone(), &self.array, None)?;
109 arro3_mod
110 .getattr(intern!(py, "Scalar"))?
111 .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
112 }
113}
114
115impl Display for PyScalar {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 write!(f, "arro3.core.Scalar<")?;
118 self.array.data_type().fmt(f)?;
119 writeln!(f, ">")?;
120
121 pretty_format_columns_with_options(
122 self.field.name(),
123 std::slice::from_ref(&self.array),
124 &default_repr_options(),
125 )
126 .map_err(|_| std::fmt::Error)?
127 .fmt(f)?;
128
129 Ok(())
130 }
131}
132
133impl Datum for PyScalar {
134 fn get(&self) -> (&dyn Array, bool) {
135 (self.array.as_ref(), true)
136 }
137}
138
139#[pymethods]
140impl PyScalar {
141 #[new]
142 #[pyo3(signature = (obj, /, r#type = None, *))]
143 fn init(py: Python, obj: &Bound<PyAny>, r#type: Option<PyField>) -> PyArrowResult<Self> {
144 if obj.hasattr(intern!(py, "__arrow_c_array__"))?
145 || obj.hasattr(intern!(py, "__arrow_c_stream__"))?
146 {
147 return Ok(obj.extract::<PyScalar>()?);
148 }
149
150 let obj = PyList::new(py, vec![obj])?;
151 let array = PyArray::init(py, &obj, r#type)?;
152 let (array, field) = array.into_inner();
153 Self::try_new(array, field)
154 }
155
156 #[pyo3(signature = (requested_schema=None))]
157 fn __arrow_c_array__<'py>(
158 &'py self,
159 py: Python<'py>,
160 requested_schema: Option<Bound<'py, PyCapsule>>,
161 ) -> PyArrowResult<Bound<'py, PyTuple>> {
162 to_array_pycapsules(py, self.field.clone(), &self.array, requested_schema)
163 }
164
165 fn __eq__(&self, py: Python, other: Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
166 if let Ok(other) = other.extract::<PyScalar>() {
167 let eq = self.array == other.array && self.field == other.field;
168 eq.into_py_any(py)
169 } else {
170 let self_py = self.as_py(py)?;
173 self_py.call_method1(py, intern!(py, "__eq__"), PyTuple::new(py, vec![other])?)
174 }
175 }
176
177 fn __repr__(&self) -> String {
178 self.to_string()
179 }
180
181 #[classmethod]
182 fn from_arrow(_cls: &Bound<PyType>, input: PyScalar) -> Self {
183 input
184 }
185
186 #[classmethod]
187 #[pyo3(name = "from_arrow_pycapsule")]
188 fn from_arrow_pycapsule_py(
189 _cls: &Bound<PyType>,
190 schema_capsule: &Bound<PyCapsule>,
191 array_capsule: &Bound<PyCapsule>,
192 ) -> PyArrowResult<Self> {
193 Self::try_from_arrow_pycapsule(schema_capsule, array_capsule)
194 }
195
196 pub(crate) fn as_py(&self, py: Python) -> PyArrowResult<Py<PyAny>> {
197 if self.array.is_null(0) {
198 return Ok(py.None());
199 }
200
201 let arr = self.array.as_ref();
202 let result = match self.array.data_type() {
203 DataType::Null => py.None(),
204 DataType::Boolean => arr.as_boolean().value(0).into_py_any(py)?,
205 DataType::Int8 => arr.as_primitive::<Int8Type>().value(0).into_py_any(py)?,
206 DataType::Int16 => arr.as_primitive::<Int16Type>().value(0).into_py_any(py)?,
207 DataType::Int32 => arr.as_primitive::<Int32Type>().value(0).into_py_any(py)?,
208 DataType::Int64 => arr.as_primitive::<Int64Type>().value(0).into_py_any(py)?,
209 DataType::UInt8 => arr.as_primitive::<UInt8Type>().value(0).into_py_any(py)?,
210 DataType::UInt16 => arr.as_primitive::<UInt16Type>().value(0).into_py_any(py)?,
211 DataType::UInt32 => arr.as_primitive::<UInt32Type>().value(0).into_py_any(py)?,
212 DataType::UInt64 => arr.as_primitive::<UInt64Type>().value(0).into_py_any(py)?,
213 DataType::Float16 => {
214 f32::from(arr.as_primitive::<Float16Type>().value(0)).into_py_any(py)?
215 }
216 DataType::Float32 => arr.as_primitive::<Float32Type>().value(0).into_py_any(py)?,
217 DataType::Float64 => arr.as_primitive::<Float64Type>().value(0).into_py_any(py)?,
218 DataType::Timestamp(time_unit, tz) => {
219 if let Some(tz) = tz {
220 let tz = PyArrowTz::from_str(tz)?;
221 match time_unit {
222 TimeUnit::Second => {
223 let value = arr.as_primitive::<TimestampSecondType>().value(0);
224 as_datetime_with_timezone::<TimestampSecondType>(value, tz)
225 .into_py_any(py)?
226 }
227 TimeUnit::Millisecond => {
228 let value = arr.as_primitive::<TimestampMillisecondType>().value(0);
229 as_datetime_with_timezone::<TimestampMillisecondType>(value, tz)
230 .into_py_any(py)?
231 }
232 TimeUnit::Microsecond => {
233 let value = arr.as_primitive::<TimestampMicrosecondType>().value(0);
234 as_datetime_with_timezone::<TimestampMicrosecondType>(value, tz)
235 .into_py_any(py)?
236 }
237 TimeUnit::Nanosecond => {
238 let value = arr.as_primitive::<TimestampNanosecondType>().value(0);
239 as_datetime_with_timezone::<TimestampNanosecondType>(value, tz)
240 .into_py_any(py)?
241 }
242 }
243 } else {
244 match time_unit {
245 TimeUnit::Second => arr
246 .as_primitive::<TimestampSecondType>()
247 .value_as_datetime(0)
248 .into_py_any(py)?,
249 TimeUnit::Millisecond => arr
250 .as_primitive::<TimestampMillisecondType>()
251 .value_as_datetime(0)
252 .into_py_any(py)?,
253 TimeUnit::Microsecond => arr
254 .as_primitive::<TimestampMicrosecondType>()
255 .value_as_datetime(0)
256 .into_py_any(py)?,
257 TimeUnit::Nanosecond => arr
258 .as_primitive::<TimestampNanosecondType>()
259 .value_as_datetime(0)
260 .into_py_any(py)?,
261 }
262 }
263 }
264 DataType::Date32 => arr
265 .as_primitive::<Date32Type>()
266 .value_as_date(0)
267 .into_py_any(py)?,
268 DataType::Date64 => arr
269 .as_primitive::<Date64Type>()
270 .value_as_date(0)
271 .into_py_any(py)?,
272 DataType::Time32(time_unit) => match time_unit {
273 TimeUnit::Second => arr
274 .as_primitive::<Time32SecondType>()
275 .value_as_time(0)
276 .into_py_any(py)?,
277 TimeUnit::Millisecond => arr
278 .as_primitive::<Time32MillisecondType>()
279 .value_as_time(0)
280 .into_py_any(py)?,
281 _ => unreachable!(),
282 },
283 DataType::Time64(time_unit) => match time_unit {
284 TimeUnit::Microsecond => arr
285 .as_primitive::<Time64MicrosecondType>()
286 .value_as_time(0)
287 .into_py_any(py)?,
288 TimeUnit::Nanosecond => arr
289 .as_primitive::<Time64NanosecondType>()
290 .value_as_time(0)
291 .into_py_any(py)?,
292
293 _ => unreachable!(),
294 },
295 DataType::Duration(time_unit) => match time_unit {
296 TimeUnit::Second => arr
297 .as_primitive::<DurationSecondType>()
298 .value_as_duration(0)
299 .into_py_any(py)?,
300 TimeUnit::Millisecond => arr
301 .as_primitive::<DurationMillisecondType>()
302 .value_as_duration(0)
303 .into_py_any(py)?,
304 TimeUnit::Microsecond => arr
305 .as_primitive::<DurationMicrosecondType>()
306 .value_as_duration(0)
307 .into_py_any(py)?,
308 TimeUnit::Nanosecond => arr
309 .as_primitive::<DurationNanosecondType>()
310 .value_as_duration(0)
311 .into_py_any(py)?,
312 },
313 DataType::Interval(_) => {
314 todo!("interval is not yet fully documented [ARROW-3097]")
316 }
317 DataType::Binary => arr.as_binary::<i32>().value(0).into_py_any(py)?,
318 DataType::FixedSizeBinary(_) => arr.as_fixed_size_binary().value(0).into_py_any(py)?,
319 DataType::LargeBinary => arr.as_binary::<i64>().value(0).into_py_any(py)?,
320 DataType::BinaryView => arr.as_binary_view().value(0).into_py_any(py)?,
321 DataType::Utf8 => arr.as_string::<i32>().value(0).into_py_any(py)?,
322 DataType::LargeUtf8 => arr.as_string::<i64>().value(0).into_py_any(py)?,
323 DataType::Utf8View => arr.as_string_view().value(0).into_py_any(py)?,
324 DataType::List(inner_field) => {
325 let inner_array = arr.as_list::<i32>().value(0);
326 list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
327 }
328 DataType::LargeList(inner_field) => {
329 let inner_array = arr.as_list::<i64>().value(0);
330 list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
331 }
332 DataType::FixedSizeList(inner_field, _list_size) => {
333 let inner_array = arr.as_fixed_size_list().value(0);
334 list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
335 }
336 DataType::ListView(_inner_field) => {
337 todo!("as_list_view does not yet exist");
338 }
341 DataType::LargeListView(_inner_field) => {
342 todo!("as_list_view does not yet exist");
343 }
346 DataType::Struct(inner_fields) => {
347 let struct_array = arr.as_struct();
348 let mut dict_py_objects = IndexMap::with_capacity(inner_fields.len());
349 for (inner_field, column) in inner_fields.iter().zip(struct_array.columns()) {
350 let scalar =
351 unsafe { PyScalar::new_unchecked(column.clone(), inner_field.clone()) };
352 dict_py_objects.insert(inner_field.name(), scalar.as_py(py)?);
353 }
354 dict_py_objects.into_py_any(py)?
355 }
356 DataType::Union(_, _) => {
357 let array = arr.as_any().downcast_ref::<UnionArray>().unwrap();
358 let scalar = PyScalar::try_from_array_ref(array.value(0))?;
359 scalar.as_py(py)?
360 }
361 DataType::Dictionary(_, _) => {
362 let array = arr.as_any_dictionary();
363 let keys = array.keys();
364 let key = match keys.data_type() {
365 DataType::Int8 => keys.as_primitive::<Int8Type>().value(0) as usize,
366 DataType::Int16 => keys.as_primitive::<Int16Type>().value(0) as usize,
367 DataType::Int32 => keys.as_primitive::<Int32Type>().value(0) as usize,
368 DataType::Int64 => keys.as_primitive::<Int64Type>().value(0) as usize,
369 DataType::UInt8 => keys.as_primitive::<UInt8Type>().value(0) as usize,
370 DataType::UInt16 => keys.as_primitive::<UInt16Type>().value(0) as usize,
371 DataType::UInt32 => keys.as_primitive::<UInt32Type>().value(0) as usize,
372 DataType::UInt64 => keys.as_primitive::<UInt64Type>().value(0) as usize,
373 _ => unreachable!(),
376 };
377 let value = array.values().slice(key, 1);
378 PyScalar::try_from_array_ref(value)?.as_py(py)?
379 }
380 DataType::Decimal32(precision, scale) => {
381 let decimal_mod = py.import(intern!(py, "decimal"))?;
382 let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
383
384 let array = arr.as_primitive::<Decimal32Type>();
385 let s = Decimal32Type::format_decimal(array.value(0), *precision, *scale);
386 decimal_class.call1((s,))?.unbind()
387 }
388 DataType::Decimal64(precision, scale) => {
389 let decimal_mod = py.import(intern!(py, "decimal"))?;
390 let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
391
392 let array = arr.as_primitive::<Decimal64Type>();
393 let s = Decimal64Type::format_decimal(array.value(0), *precision, *scale);
394 decimal_class.call1((s,))?.unbind()
395 }
396 DataType::Decimal128(precision, scale) => {
397 let decimal_mod = py.import(intern!(py, "decimal"))?;
398 let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
399
400 let array = arr.as_primitive::<Decimal128Type>();
401 let s = Decimal128Type::format_decimal(array.value(0), *precision, *scale);
402 decimal_class.call1((s,))?.unbind()
403 }
404 DataType::Decimal256(precision, scale) => {
405 let decimal_mod = py.import(intern!(py, "decimal"))?;
406 let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
407
408 let array = arr.as_primitive::<Decimal256Type>();
409 let s = Decimal256Type::format_decimal(array.value(0), *precision, *scale);
410 decimal_class.call1((s,))?.unbind()
411 }
412 DataType::Map(_, _) => {
413 let array = arr.as_map();
414 let struct_arr = array.value(0);
415 let key_arr = struct_arr.column_by_name("key").unwrap();
416 let value_arr = struct_arr.column_by_name("value").unwrap();
417
418 let mut entries = Vec::with_capacity(struct_arr.len());
419 for i in 0..struct_arr.len() {
420 let py_key = PyScalar::try_from_array_ref(key_arr.slice(i, 1))?.as_py(py)?;
421 let py_value =
422 PyScalar::try_from_array_ref(value_arr.slice(i, 1))?.as_py(py)?;
423 entries.push(PyTuple::new(py, vec![py_key, py_value])?);
424 }
425
426 entries.into_py_any(py)?
427 }
428 DataType::RunEndEncoded(_, _) => {
429 todo!()
430 }
431 };
432 Ok(result)
433 }
434
435 fn cast(&self, target_type: PyField) -> PyArrowResult<Arro3Scalar> {
436 let new_field = target_type.into_inner();
437 let new_array = cast(&self.array, new_field.data_type())?;
438 Ok(PyScalar::try_new(new_array, new_field).unwrap().into())
439 }
440
441 #[getter]
442 #[pyo3(name = "field")]
443 fn py_field(&self) -> Arro3Field {
444 self.field.clone().into()
445 }
446
447 #[getter]
448 fn is_valid(&self) -> bool {
449 self.array.is_valid(0)
450 }
451
452 #[getter]
453 fn r#type(&self) -> Arro3DataType {
454 self.field.data_type().clone().into()
455 }
456}
457
458fn list_values_to_py(
459 py: Python,
460 inner_array: ArrayRef,
461 inner_field: &Arc<Field>,
462) -> PyArrowResult<Vec<Py<PyAny>>> {
463 let mut list_py_objects = Vec::with_capacity(inner_array.len());
464 for i in 0..inner_array.len() {
465 let scalar =
466 unsafe { PyScalar::new_unchecked(inner_array.slice(i, 1), inner_field.clone()) };
467 list_py_objects.push(scalar.as_py(py)?);
468 }
469 Ok(list_py_objects)
470}