batch_mode_batch_client/
download_output_file.rs1crate::ix!();
3
4#[async_trait]
5impl<E> DownloadOutputFile<E> for BatchFileTriple
6where
7 E: From<BatchDownloadError>
8 + From<std::io::Error>
9 + From<BatchMetadataError>
10 + From<OpenAIClientError>
11 + Debug,
12{
13 async fn download_output_file(
14 &mut self,
15 client: &dyn LanguageModelClientInterface<E>,
16 ) -> Result<(), E> {
17 info!("downloading batch output file");
18
19 if let Some(out_path) = &self.output() {
22 if out_path.exists() {
23 warn!(
24 "Output file already present on disk at path={:?}. \
25 Aborting to avoid overwriting.",
26 out_path
27 );
28 return Err(BatchDownloadError::OutputFileAlreadyExists {
29 triple: self.clone(),
30 }
31 .into());
32 }
33 }
34
35 let metadata_filename: PathBuf = if let Some(path) = self.associated_metadata() {
37 path.clone()
38 } else {
39 self.effective_metadata_filename()
40 };
41 debug!("Using metadata file for output: {:?}", metadata_filename);
42
43 let metadata = BatchMetadata::load_from_file(&metadata_filename).await?;
44 let output_file_id = metadata.output_file_id()?; let file_content = client.file_content(output_file_id).await?;
47
48 let output_path = self.effective_output_filename();
49 if let Some(parent) = output_path.parent() {
50 tokio::fs::create_dir_all(parent).await.ok();
51 }
52
53 std::fs::write(&output_path, file_content)?;
54 self.set_output_path(Some(output_path));
55
56 Ok(())
57 }
58}
59
#[cfg(test)]
mod download_output_file_tests {
    use super::*;
    use futures::executor::block_on;
    use std::fs;
    use tempfile::tempdir;
    use tracing::{debug, error, info, trace, warn};

    /// Happy path: the mock client serves the output file and its contents
    /// end up at the configured output path.
    #[traced_test]
    async fn test_download_output_file_ok() {
        info!("Beginning test_download_output_file_ok");
        trace!("Constructing mock client...");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);

        let output_file_id = "some_output_file_id";
        {
            let mut files_guard = mock_client.files().write().unwrap();
            files_guard.insert(output_file_id.to_string(), Bytes::from("mock output contents"));
        }

        let tmpdir = tempdir().unwrap();
        let metadata_path = tmpdir.path().join("metadata.json");
        let metadata = BatchMetadataBuilder::default()
            .batch_id("batch_for_download_output_ok".to_string())
            .input_file_id("some_input_file_id".to_string())
            .output_file_id(Some(output_file_id.to_string()))
            .error_file_id(None)
            .build()
            .unwrap();
        metadata.save_to_file(&metadata_path).await.unwrap();

        trace!("Creating BatchFileTriple with known metadata path...");
        let mut triple = BatchFileTriple::new_for_test_with_metadata_path(metadata_path.clone());
        triple.set_metadata_path(Some(metadata_path.clone()));

        let out_path = tmpdir.path().join("output.json");
        triple.set_output_path(Some(out_path.clone()));

        trace!("Calling download_output_file...");
        let result = triple.download_output_file(&mock_client).await;
        debug!("Result from download_output_file: {:?}", result);

        assert!(result.is_ok(), "Should succeed for a valid output file");
        let contents = fs::read_to_string(&out_path).unwrap();
        pretty_assert_eq!(contents, "mock output contents");

        info!("test_download_output_file_ok passed");
    }

    /// An existing output file must cause an error and must be left untouched.
    #[traced_test]
    async fn test_download_output_file_already_exists() {
        info!("Beginning test_download_output_file_already_exists");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);

        let tmpdir = tempdir().unwrap();
        let metadata_path = tmpdir.path().join("metadata.json");
        let metadata = BatchMetadataBuilder::default()
            .batch_id("batch_exists_output")
            .input_file_id("some_input_file_id".to_string())
            .output_file_id(Some("already_exists_output_file_id".to_string()))
            .error_file_id(None)
            .build()
            .unwrap();
        metadata.save_to_file(&metadata_path).await.unwrap();

        let mut triple = BatchFileTriple::new_for_test_with_metadata_path(metadata_path.clone());
        triple.set_metadata_path(Some(metadata_path.clone()));

        let existing_output_path = tmpdir.path().join("output.json");
        fs::write(&existing_output_path, b"existing content").unwrap();
        triple.set_output_path(Some(existing_output_path.clone()));

        let result = triple.download_output_file(&mock_client).await;
        debug!("Result from download_output_file: {:?}", result);

        assert!(
            result.is_err(),
            "Should fail if output file already exists on disk"
        );
        // The whole point of the guard is to avoid clobbering existing data.
        let preserved = fs::read_to_string(&existing_output_path).unwrap();
        pretty_assert_eq!(preserved, "existing content");
        info!("test_download_output_file_already_exists passed");
    }

    /// Metadata without an `output_file_id` must fail without writing anything.
    #[traced_test]
    async fn test_download_output_file_missing_output_file_id() {
        info!("Beginning test_download_output_file_missing_output_file_id");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();
        debug!("Mock client: {:?}", mock_client);

        let tmpdir = tempdir().unwrap();
        let metadata_path = tmpdir.path().join("metadata.json");
        let metadata = BatchMetadataBuilder::default()
            .batch_id("batch_no_out_id")
            .input_file_id("input_file_id".to_string())
            .output_file_id(None)
            .error_file_id(None)
            .build()
            .unwrap();
        metadata.save_to_file(&metadata_path).await.unwrap();

        let mut triple = BatchFileTriple::new_for_test_with_metadata_path(metadata_path.clone());
        triple.set_metadata_path(Some(metadata_path.clone()));

        let out_path = tmpdir.path().join("will_not_be_written.json");
        triple.set_output_path(Some(out_path.clone()));

        let result = triple.download_output_file(&mock_client).await;
        debug!("Result from download_output_file: {:?}", result);

        assert!(
            result.is_err(),
            "Should fail if output_file_id is not present in metadata"
        );
        assert!(
            !out_path.exists(),
            "No output file should be written when output_file_id is missing"
        );
        info!("test_download_output_file_missing_output_file_id passed");
    }

    /// An `output_file_id` unknown to the client must fail without writing anything.
    #[traced_test]
    async fn test_download_output_file_client_file_not_found() {
        info!("Beginning test_download_output_file_client_file_not_found");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();

        let tmpdir = tempdir().unwrap();
        let metadata_path = tmpdir.path().join("metadata.json");
        let metadata = BatchMetadataBuilder::default()
            .batch_id("batch_out_file_not_found")
            .input_file_id("some_input".to_string())
            .output_file_id(Some("out_file_that_does_not_exist".to_string()))
            .error_file_id(None)
            .build()
            .unwrap();
        metadata.save_to_file(&metadata_path).await.unwrap();

        let mut triple = BatchFileTriple::new_for_test_with_metadata_path(metadata_path.clone());
        triple.set_metadata_path(Some(metadata_path.clone()));

        let out_path = tmpdir.path().join("output_file.json");
        triple.set_output_path(Some(out_path.clone()));

        let result = triple.download_output_file(&mock_client).await;
        debug!("Result from download_output_file: {:?}", result);

        assert!(
            result.is_err(),
            "Should fail if the mock client cannot find the output file_id"
        );
        assert!(
            !out_path.exists(),
            "No output file should be written when the client lookup fails"
        );
        info!("test_download_output_file_client_file_not_found passed");
    }

    /// Writing into a read-only directory must surface an I/O error.
    #[traced_test]
    async fn test_download_output_file_io_write_error() {
        info!("Beginning test_download_output_file_io_write_error");
        let mock_client = MockLanguageModelClientBuilder::<MockBatchClientError>::default()
            .build()
            .unwrap();

        let output_file_id = "some_out_file_id_for_io_error";
        {
            let mut files_guard = mock_client.files().write().unwrap();
            files_guard.insert(output_file_id.to_string(), Bytes::from("output content"));
        }

        let tmpdir_meta = tempdir().unwrap();
        let tmpdir_readonly = tempdir().unwrap();

        let metadata_path = tmpdir_meta.path().join("metadata.json");
        let metadata = BatchMetadataBuilder::default()
            .batch_id("batch_io_error")
            .input_file_id("some_input".to_string())
            .output_file_id(Some(output_file_id.to_string()))
            .error_file_id(None)
            .build()
            .unwrap();
        metadata.save_to_file(&metadata_path).await.unwrap();
        debug!("Metadata saved at {:?}", metadata_path);

        let mut triple = BatchFileTriple::new_for_test_with_metadata_path(metadata_path.clone());
        triple.set_metadata_path(Some(metadata_path.clone()));

        let out_path = tmpdir_readonly.path().join("output.json");
        triple.set_output_path(Some(out_path.clone()));

        // Remember the original permissions so we can restore exactly them
        // (set_readonly(false) would make the directory world-writable on Unix).
        let original_perms = fs::metadata(tmpdir_readonly.path()).unwrap().permissions();
        let mut readonly_perms = original_perms.clone();
        readonly_perms.set_readonly(true);
        fs::set_permissions(tmpdir_readonly.path(), readonly_perms).unwrap();

        // Privileged users (e.g. root) and some platforms ignore the read-only
        // bit on directories; if we can still write, the scenario can't be
        // reproduced here, so skip instead of failing spuriously.
        let probe_path = tmpdir_readonly.path().join("write_probe");
        let dir_still_writable = fs::write(&probe_path, b"probe").is_ok();

        let result = if dir_still_writable {
            None
        } else {
            Some(triple.download_output_file(&mock_client).await)
        };
        debug!("Result from download_output_file: {:?}", result);

        // Restore permissions so tempdir cleanup can remove the directory.
        fs::set_permissions(tmpdir_readonly.path(), original_perms).unwrap();

        match result {
            None => {
                warn!("Directory still writable after set_readonly(true); skipping I/O error check");
            }
            Some(result) => {
                assert!(
                    result.is_err(),
                    "Should fail with an I/O error when the directory is read-only"
                );
            }
        }
        info!("test_download_output_file_io_write_error passed");
    }
}