Skip to main content

alimentar/doctest/
mod.rs

1//! Python doctest extraction and corpus management.
2//!
3//! This module provides tools for extracting Python doctests from source files
4//! and converting them to Arrow/Parquet format for ML training data.
5
6mod parser;
7
8use std::sync::Arc;
9
10use arrow::{
11    array::{ArrayRef, RecordBatch, StringArray},
12    datatypes::{DataType, Field, Schema, SchemaRef},
13};
14use chrono::{DateTime, Utc};
15pub use parser::{is_prose_continuation, DocTestParser};
16use serde::{Deserialize, Serialize};
17
18use crate::{ArrowDataset, Result};
19
20/// A single extracted Python doctest.
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub struct DocTest {
23    /// Module path (e.g., "collections.abc")
24    pub module: String,
25    /// Function/class name (e.g., "Hashable.__hash__")
26    pub function: String,
27    /// Input code (e.g., ">>> x = 5\n>>> x + 3")
28    pub input: String,
29    /// Expected output (e.g., "8")
30    pub expected: String,
31    /// Optional function signature (deferred to v2)
32    pub signature: Option<String>,
33}
34
35impl DocTest {
36    /// Create a new `DocTest`.
37    #[must_use]
38    pub fn new(
39        module: impl Into<String>,
40        function: impl Into<String>,
41        input: impl Into<String>,
42        expected: impl Into<String>,
43    ) -> Self {
44        Self {
45            module: module.into(),
46            function: function.into(),
47            input: input.into(),
48            expected: expected.into(),
49            signature: None,
50        }
51    }
52
53    /// Set the function signature.
54    #[must_use]
55    pub fn with_signature(mut self, signature: impl Into<String>) -> Self {
56        self.signature = Some(signature.into());
57        self
58    }
59}
60
61/// A corpus of extracted doctests from a Python source.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct DocTestCorpus {
64    /// Source identifier (e.g., "cpython", "numpy", "pandas")
65    pub source: String,
66    /// Version or git SHA of the source
67    pub version: String,
68    /// When the extraction was performed
69    pub extracted_at: DateTime<Utc>,
70    /// The extracted doctests
71    pub doctests: Vec<DocTest>,
72}
73
74impl DocTestCorpus {
75    /// Create a new empty corpus.
76    #[must_use]
77    pub fn new(source: impl Into<String>, version: impl Into<String>) -> Self {
78        Self {
79            source: source.into(),
80            version: version.into(),
81            extracted_at: Utc::now(),
82            doctests: Vec::new(),
83        }
84    }
85
86    /// Add a doctest to the corpus.
87    pub fn push(&mut self, doctest: DocTest) {
88        self.doctests.push(doctest);
89    }
90
91    /// Number of doctests in the corpus.
92    #[must_use]
93    pub fn len(&self) -> usize {
94        self.doctests.len()
95    }
96
97    /// Check if corpus is empty.
98    #[must_use]
99    pub fn is_empty(&self) -> bool {
100        self.doctests.is_empty()
101    }
102
103    /// Get the Arrow schema for doctest records.
104    #[must_use]
105    pub fn schema() -> SchemaRef {
106        Arc::new(Schema::new(vec![
107            Field::new("source", DataType::Utf8, false),
108            Field::new("version", DataType::Utf8, false),
109            Field::new("module", DataType::Utf8, false),
110            Field::new("function", DataType::Utf8, false),
111            Field::new("input", DataType::Utf8, false),
112            Field::new("expected", DataType::Utf8, false),
113            Field::new("signature", DataType::Utf8, true),
114        ]))
115    }
116
117    /// Convert the corpus to an Arrow `RecordBatch`.
118    pub fn to_record_batch(&self) -> Result<RecordBatch> {
119        let len = self.doctests.len();
120
121        let source_array: ArrayRef = Arc::new(StringArray::from(vec![self.source.as_str(); len]));
122        let version_array: ArrayRef = Arc::new(StringArray::from(vec![self.version.as_str(); len]));
123        let module_array: ArrayRef = Arc::new(StringArray::from(
124            self.doctests
125                .iter()
126                .map(|d| d.module.as_str())
127                .collect::<Vec<_>>(),
128        ));
129        let function_array: ArrayRef = Arc::new(StringArray::from(
130            self.doctests
131                .iter()
132                .map(|d| d.function.as_str())
133                .collect::<Vec<_>>(),
134        ));
135        let input_array: ArrayRef = Arc::new(StringArray::from(
136            self.doctests
137                .iter()
138                .map(|d| d.input.as_str())
139                .collect::<Vec<_>>(),
140        ));
141        let expected_array: ArrayRef = Arc::new(StringArray::from(
142            self.doctests
143                .iter()
144                .map(|d| d.expected.as_str())
145                .collect::<Vec<_>>(),
146        ));
147        let signature_array: ArrayRef = Arc::new(StringArray::from(
148            self.doctests
149                .iter()
150                .map(|d| d.signature.as_deref())
151                .collect::<Vec<_>>(),
152        ));
153
154        let batch = RecordBatch::try_new(
155            Self::schema(),
156            vec![
157                source_array,
158                version_array,
159                module_array,
160                function_array,
161                input_array,
162                expected_array,
163                signature_array,
164            ],
165        )?;
166
167        Ok(batch)
168    }
169
170    /// Convert the corpus to an `ArrowDataset`.
171    pub fn to_dataset(&self) -> Result<ArrowDataset> {
172        let batch = self.to_record_batch()?;
173        ArrowDataset::from_batch(batch)
174    }
175
176    /// Merge another corpus into this one.
177    pub fn merge(&mut self, other: Self) {
178        self.doctests.extend(other.doctests);
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `as_any`, `is_null`, and `null_count` into scope for column checks.
    use arrow::array::Array;

    /// Fetch column `index` of `batch` as a `StringArray`; every doctest column
    /// is declared `Utf8` in `DocTestCorpus::schema`.
    fn utf8_column(batch: &RecordBatch, index: usize) -> &StringArray {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("every doctest column is Utf8")
    }

    #[test]
    fn test_doctest_new() {
        let dt = DocTest::new("os.path", "join", ">>> join('a', 'b')", "'a/b'");
        assert_eq!(dt.module, "os.path");
        assert_eq!(dt.function, "join");
        assert!(dt.signature.is_none());
    }

    #[test]
    fn test_doctest_with_signature() {
        let dt = DocTest::new("os.path", "join", ">>> join('a', 'b')", "'a/b'")
            .with_signature("def join(*paths) -> str");
        assert_eq!(dt.signature, Some("def join(*paths) -> str".to_string()));
    }

    #[test]
    fn test_corpus_basic() {
        let mut corpus = DocTestCorpus::new("cpython", "v3.12.0");
        assert!(corpus.is_empty());

        corpus.push(DocTest::new("os", "getcwd", ">>> getcwd()", "'/home'"));
        assert_eq!(corpus.len(), 1);
    }

    #[test]
    fn test_corpus_to_record_batch() {
        let mut corpus = DocTestCorpus::new("numpy", "1.26.0");
        corpus.push(DocTest::new(
            "numpy",
            "array",
            ">>> array([1,2])",
            "array([1, 2])",
        ));
        corpus.push(DocTest::new(
            "numpy",
            "zeros",
            ">>> zeros(3)",
            "array([0., 0., 0.])",
        ));

        let batch = corpus.to_record_batch().expect("should create batch");
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 7);
    }

    #[test]
    fn test_record_batch_values_and_null_signature() {
        let mut corpus = DocTestCorpus::new("cpython", "v3.12.0");
        corpus.push(
            DocTest::new("os.path", "join", ">>> join('a', 'b')", "'a/b'")
                .with_signature("def join(*paths) -> str"),
        );
        corpus.push(DocTest::new("os", "getcwd", ">>> getcwd()", "'/home'"));

        let batch = corpus.to_record_batch().expect("should create batch");

        assert_eq!(utf8_column(&batch, 0).value(1), "cpython");
        assert_eq!(utf8_column(&batch, 1).value(0), "v3.12.0");
        assert_eq!(utf8_column(&batch, 2).value(0), "os.path");
        assert_eq!(utf8_column(&batch, 3).value(1), "getcwd");
        assert_eq!(utf8_column(&batch, 4).value(1), ">>> getcwd()");
        assert_eq!(utf8_column(&batch, 5).value(0), "'a/b'");

        let signature = utf8_column(&batch, 6);
        assert_eq!(signature.value(0), "def join(*paths) -> str");
        assert!(signature.is_null(1));
        assert_eq!(signature.null_count(), 1);
    }

    #[test]
    fn test_empty_corpus_to_record_batch() {
        let corpus = DocTestCorpus::new("pandas", "2.1.0");
        let batch = corpus
            .to_record_batch()
            .expect("empty corpus should still convert");
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.schema(), DocTestCorpus::schema());
    }

    #[test]
    fn test_corpus_merge_keeps_own_metadata() {
        let mut corpus = DocTestCorpus::new("cpython", "v3.12.0");
        corpus.push(DocTest::new("os", "getcwd", ">>> getcwd()", "'/home'"));

        let mut other = DocTestCorpus::new("numpy", "1.26.0");
        other.push(DocTest::new(
            "numpy",
            "zeros",
            ">>> zeros(3)",
            "array([0., 0., 0.])",
        ));

        corpus.merge(other);
        assert_eq!(corpus.len(), 2);
        assert_eq!(corpus.source, "cpython");
        assert_eq!(corpus.version, "v3.12.0");
        assert_eq!(corpus.doctests[1].function, "zeros");

        // Merged rows are exported under the receiving corpus's source.
        let batch = corpus.to_record_batch().expect("should create batch");
        assert_eq!(utf8_column(&batch, 0).value(1), "cpython");
    }

    #[test]
    fn test_corpus_schema() {
        let schema = DocTestCorpus::schema();
        assert_eq!(schema.fields().len(), 7);
        assert_eq!(schema.field(0).name(), "source");
        assert_eq!(schema.field(6).name(), "signature");
        assert!(schema.field(6).is_nullable());
    }
}