1mod parser;
7
8use std::sync::Arc;
9
10use arrow::{
11 array::{ArrayRef, RecordBatch, StringArray},
12 datatypes::{DataType, Field, Schema, SchemaRef},
13};
14use chrono::{DateTime, Utc};
15pub use parser::{is_prose_continuation, DocTestParser};
16use serde::{Deserialize, Serialize};
17
18use crate::{ArrowDataset, Result};
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub struct DocTest {
23 pub module: String,
25 pub function: String,
27 pub input: String,
29 pub expected: String,
31 pub signature: Option<String>,
33}
34
35impl DocTest {
36 #[must_use]
38 pub fn new(
39 module: impl Into<String>,
40 function: impl Into<String>,
41 input: impl Into<String>,
42 expected: impl Into<String>,
43 ) -> Self {
44 Self {
45 module: module.into(),
46 function: function.into(),
47 input: input.into(),
48 expected: expected.into(),
49 signature: None,
50 }
51 }
52
53 #[must_use]
55 pub fn with_signature(mut self, signature: impl Into<String>) -> Self {
56 self.signature = Some(signature.into());
57 self
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct DocTestCorpus {
64 pub source: String,
66 pub version: String,
68 pub extracted_at: DateTime<Utc>,
70 pub doctests: Vec<DocTest>,
72}
73
74impl DocTestCorpus {
75 #[must_use]
77 pub fn new(source: impl Into<String>, version: impl Into<String>) -> Self {
78 Self {
79 source: source.into(),
80 version: version.into(),
81 extracted_at: Utc::now(),
82 doctests: Vec::new(),
83 }
84 }
85
86 pub fn push(&mut self, doctest: DocTest) {
88 self.doctests.push(doctest);
89 }
90
91 #[must_use]
93 pub fn len(&self) -> usize {
94 self.doctests.len()
95 }
96
97 #[must_use]
99 pub fn is_empty(&self) -> bool {
100 self.doctests.is_empty()
101 }
102
103 #[must_use]
105 pub fn schema() -> SchemaRef {
106 Arc::new(Schema::new(vec![
107 Field::new("source", DataType::Utf8, false),
108 Field::new("version", DataType::Utf8, false),
109 Field::new("module", DataType::Utf8, false),
110 Field::new("function", DataType::Utf8, false),
111 Field::new("input", DataType::Utf8, false),
112 Field::new("expected", DataType::Utf8, false),
113 Field::new("signature", DataType::Utf8, true),
114 ]))
115 }
116
117 pub fn to_record_batch(&self) -> Result<RecordBatch> {
119 let len = self.doctests.len();
120
121 let source_array: ArrayRef = Arc::new(StringArray::from(vec![self.source.as_str(); len]));
122 let version_array: ArrayRef = Arc::new(StringArray::from(vec![self.version.as_str(); len]));
123 let module_array: ArrayRef = Arc::new(StringArray::from(
124 self.doctests
125 .iter()
126 .map(|d| d.module.as_str())
127 .collect::<Vec<_>>(),
128 ));
129 let function_array: ArrayRef = Arc::new(StringArray::from(
130 self.doctests
131 .iter()
132 .map(|d| d.function.as_str())
133 .collect::<Vec<_>>(),
134 ));
135 let input_array: ArrayRef = Arc::new(StringArray::from(
136 self.doctests
137 .iter()
138 .map(|d| d.input.as_str())
139 .collect::<Vec<_>>(),
140 ));
141 let expected_array: ArrayRef = Arc::new(StringArray::from(
142 self.doctests
143 .iter()
144 .map(|d| d.expected.as_str())
145 .collect::<Vec<_>>(),
146 ));
147 let signature_array: ArrayRef = Arc::new(StringArray::from(
148 self.doctests
149 .iter()
150 .map(|d| d.signature.as_deref())
151 .collect::<Vec<_>>(),
152 ));
153
154 let batch = RecordBatch::try_new(
155 Self::schema(),
156 vec![
157 source_array,
158 version_array,
159 module_array,
160 function_array,
161 input_array,
162 expected_array,
163 signature_array,
164 ],
165 )?;
166
167 Ok(batch)
168 }
169
170 pub fn to_dataset(&self) -> Result<ArrowDataset> {
172 let batch = self.to_record_batch()?;
173 ArrowDataset::from_batch(batch)
174 }
175
176 pub fn merge(&mut self, other: Self) {
178 self.doctests.extend(other.doctests);
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `Array::is_null` into scope for concrete array types.
    use arrow::array::Array;

    /// Downcasts column `index` of `batch` to a `StringArray`, panicking if it is not UTF-8.
    fn string_column(batch: &RecordBatch, index: usize) -> &StringArray {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("column should be Utf8")
    }

    #[test]
    fn test_doctest_new() {
        let dt = DocTest::new("os.path", "join", ">>> join('a', 'b')", "'a/b'");
        assert_eq!(dt.module, "os.path");
        assert_eq!(dt.function, "join");
        assert!(dt.signature.is_none());
    }

    #[test]
    fn test_doctest_with_signature() {
        let dt = DocTest::new("os.path", "join", ">>> join('a', 'b')", "'a/b'")
            .with_signature("def join(*paths) -> str");
        assert_eq!(dt.signature, Some("def join(*paths) -> str".to_string()));
    }

    #[test]
    fn test_corpus_basic() {
        let mut corpus = DocTestCorpus::new("cpython", "v3.12.0");
        assert!(corpus.is_empty());

        corpus.push(DocTest::new("os", "getcwd", ">>> getcwd()", "'/home'"));
        assert_eq!(corpus.len(), 1);
    }

    #[test]
    fn test_corpus_to_record_batch() {
        let mut corpus = DocTestCorpus::new("numpy", "1.26.0");
        corpus.push(DocTest::new(
            "numpy",
            "array",
            ">>> array([1,2])",
            "array([1, 2])",
        ));
        corpus.push(DocTest::new(
            "numpy",
            "zeros",
            ">>> zeros(3)",
            "array([0., 0., 0.])",
        ));

        let batch = corpus.to_record_batch().expect("should create batch");
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 7);
    }

    #[test]
    fn test_corpus_to_record_batch_values() {
        let mut corpus = DocTestCorpus::new("cpython", "v3.12.0");
        corpus.push(
            DocTest::new("os.path", "join", ">>> join('a', 'b')", "'a/b'")
                .with_signature("def join(*paths) -> str"),
        );
        corpus.push(DocTest::new("os", "getcwd", ">>> getcwd()", "'/home'"));

        let batch = corpus.to_record_batch().expect("should create batch");
        assert_eq!(batch.schema(), DocTestCorpus::schema());

        // Corpus-level metadata is repeated on every row.
        let source = string_column(&batch, 0);
        assert_eq!(source.value(0), "cpython");
        assert_eq!(source.value(1), "cpython");
        assert_eq!(string_column(&batch, 1).value(1), "v3.12.0");

        // Per-doctest fields land in their own columns, in insertion order.
        assert_eq!(string_column(&batch, 2).value(0), "os.path");
        assert_eq!(string_column(&batch, 3).value(1), "getcwd");
        assert_eq!(string_column(&batch, 4).value(0), ">>> join('a', 'b')");
        assert_eq!(string_column(&batch, 5).value(1), "'/home'");

        // A missing signature becomes a null, a present one is kept verbatim.
        let signature = string_column(&batch, 6);
        assert_eq!(signature.value(0), "def join(*paths) -> str");
        assert!(signature.is_null(1));
    }

    #[test]
    fn test_empty_corpus_to_record_batch() {
        let corpus = DocTestCorpus::new("cpython", "v3.12.0");
        let batch = corpus
            .to_record_batch()
            .expect("empty corpus should still convert");
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 7);
    }

    #[test]
    fn test_corpus_merge_keeps_own_metadata() {
        let mut base = DocTestCorpus::new("cpython", "v3.12.0");
        base.push(DocTest::new("os", "getcwd", ">>> getcwd()", "'/home'"));

        let mut other = DocTestCorpus::new("numpy", "1.26.0");
        other.push(DocTest::new(
            "numpy",
            "zeros",
            ">>> zeros(3)",
            "array([0., 0., 0.])",
        ));

        base.merge(other);
        assert_eq!(base.len(), 2);
        assert_eq!(base.doctests[1].function, "zeros");
        // `merge` only takes the doctests; the receiver's metadata is unchanged.
        assert_eq!(base.source, "cpython");
        assert_eq!(base.version, "v3.12.0");
    }

    #[test]
    fn test_corpus_schema() {
        let schema = DocTestCorpus::schema();
        assert_eq!(schema.fields().len(), 7);
        assert_eq!(schema.field(0).name(), "source");
        assert_eq!(schema.field(6).name(), "signature");
        assert!(schema.field(6).is_nullable());
    }
}