batch_mode_process_response/
handle_successful_response.rs

1// ---------------- [ File: batch-mode-process-response/src/handle_successful_response.rs ]
2crate::ix!();
3
4/// Produce a *flattened* JSON payload suitable for saving to `target/`
5///
6/// * If the original assistant JSON already matches the target structure,
7///   we return it unchanged.
8/// * If it was wrapped in a `{ "fields": { … } }` envelope, we drop the
9///   envelope and return only the inner object so that subsequent
10///   `T::load_from_file()` calls (which expect a *flat* object) succeed.
11///
12/// All paths are traced for debuggability.
13#[instrument(level = "trace", skip_all)]
14fn flatten_json_for_persistence(original: &serde_json::Value) -> serde_json::Value {
15    if let Some(inner) = original.get("fields") {
16        trace!("Detected `fields` wrapper — flattening for persistence.");
17        inner.clone()
18    } else {
19        trace!("No `fields` wrapper — JSON already flat.");
20        original.clone()
21    }
22}
23
24#[instrument(level = "trace", skip_all)]
25pub async fn handle_successful_response<T>(
26    success_body: &BatchSuccessResponseBody,
27    workspace: &dyn BatchWorkspaceInterface,
28    expected_content_type: &ExpectedContentType,
29) -> Result<(), BatchSuccessResponseHandlingError>
30where
31    T: 'static + Send + Sync + Named + DeserializeOwned + GetTargetPathForAIExpansion,
32{
33    trace!(
34        "Entering handle_successful_response with success_body ID: {}",
35        success_body.id()
36    );
37    trace!(
38        "success_body => finish_reason={:?}, total_choices={}",
39        success_body.choices().get(0).map(|c| c.finish_reason()),
40        success_body.choices().len()
41    );
42
43    let choice = &success_body.choices()[0];
44    let message_content = choice.message().content();
45    trace!(
46        "Pulled first choice => finish_reason={:?}",
47        choice.finish_reason()
48    );
49
50    if *choice.finish_reason() == FinishReason::Length {
51        trace!("Detected finish_reason=Length => calling handle_finish_reason_length");
52        handle_finish_reason_length(success_body.id(), message_content).await?;
53        trace!(
54            "Returned from handle_finish_reason_length with success_body ID: {}",
55            success_body.id()
56        );
57    }
58
59    match expected_content_type {
60        ExpectedContentType::Json => {
61            trace!(
62                "ExpectedContentType::Json => about to extract/repair JSON for success_body ID: {}",
63                success_body.id()
64            );
65
66            match message_content.extract_clean_parse_json_with_repair() {
67                Ok(json_content) => {
68                    debug!(
69                        "JSON parse/repair succeeded for success_body ID: {}",
70                        success_body.id()
71                    );
72
73                    // ---------- DESERIALISE (with optional wrapper) ----------
74                    let typed_item: T = match deserialize_json_with_optional_fields_wrapper::<T>(&json_content) {
75                        Ok(item) => item,
76                        Err(e) => {
77                            error!("Deserialization into T failed: {:?}", e);
78                            handle_failed_json_repair(success_body.id(), message_content, workspace).await?;
79                            return Err(e.into());
80                        }
81                    };
82                    //------------------------------------------------------------------
83
84                    trace!("Wrapping typed_item in Arc => T::name()={}", typed_item.name());
85                    let typed_item_arc: Arc<
86                        dyn GetTargetPathForAIExpansion + Send + Sync + 'static,
87                    > = Arc::new(typed_item);
88
89                    // Choose where we will write the output **before** we flatten it.
90                    let target_path = workspace.target_path(&typed_item_arc, expected_content_type);
91                    trace!("Target path computed => {:?}", target_path);
92
93                    // ---------- NEW: ALWAYS WRITE A *FLAT* OBJECT ----------
94                    let flattened_json = flatten_json_for_persistence(&json_content);
95                    let serialized_json = match serde_json::to_string_pretty(&flattened_json) {
96                        Ok(s) => {
97                            trace!(
98                                "Successfully created pretty JSON string for success_body ID: {}",
99                                success_body.id()
100                            );
101                            s
102                        }
103                        Err(e) => {
104                            error!("Re‑serialization to pretty JSON failed: {:?}", e);
105                            return Err(JsonParseError::SerdeError(e).into());
106                        }
107                    };
108                    //--------------------------------------------------------
109
110                    info!("writing JSON output to {:?}", target_path);
111                    write_to_file(&target_path, &serialized_json).await?;
112                    trace!("Successfully wrote JSON file => {:?}", target_path);
113                    trace!(
114                        "Exiting handle_successful_response with success_body ID: {}",
115                        success_body.id()
116                    );
117                    Ok(())
118                }
119                Err(e_extract) => {
120                    warn!(
121                        "JSON extraction/repair failed for success_body ID: {} with error: {:?}",
122                        success_body.id(),
123                        e_extract
124                    );
125                    let failed_id = success_body.id();
126                    trace!("Calling handle_failed_json_repair for ID={}", failed_id);
127                    handle_failed_json_repair(failed_id, message_content, workspace).await?;
128                    trace!(
129                        "Returned from handle_failed_json_repair => now returning error for ID={}",
130                        failed_id
131                    );
132                    Err(e_extract.into())
133                }
134            }
135        }
136        ExpectedContentType::PlainText => {
137            trace!(
138                "Received plain text content for request {} => length={}",
139                success_body.id(),
140                message_content.len()
141            );
142            let index = BatchIndex::from_uuid_str(success_body.id())?;
143            trace!("Parsed BatchIndex => {:?}", index);
144
145            let text_path = workspace.text_storage_path(&index);
146            info!("writing plain text output to {:?}", text_path);
147            write_to_file(&text_path, message_content.as_str()).await?;
148            trace!("Successfully wrote plain text file => {:?}", text_path);
149            trace!(
150                "Exiting handle_successful_response with success_body ID: {}",
151                success_body.id()
152            );
153            Ok(())
154        }
155        _ => {
156            todo!("Unsupported ExpectedContentType variant encountered.")
157        }
158    }
159}
160
/// Tests for `handle_successful_response`.
#[cfg(test)]
mod handle_successful_response_tests {
    use super::*;

    /// Minimal deserialization target used as `T` in the tests below.
    #[derive(Debug, Deserialize, Serialize, NamedItem)]
    pub struct MockItemForSuccess {
        pub name: String,
    }

    /// Content that is not JSON must make the JSON branch return an error and
    /// leave a file named after the response ID in the failed-repairs dir.
    #[traced_test]
    async fn test_handle_successful_response_json_failure() {
        trace!("===== BEGIN TEST: test_handle_successful_response_json_failure =====");

        // 1) Use ephemeral default workspace (no overrides).
        let workspace = BatchWorkspace::new_temp().await.unwrap();
        info!("Created ephemeral workspace: {:?}", workspace);

        // 2) Remember where failed repairs are stored so we can inspect it
        //    after the call (nothing is asserted about its contents up front).
        let repairs_dir = workspace.failed_json_repairs_dir();

        // 3) Create a response that is *not* valid JSON
        let invalid_msg = ChatCompletionResponseMessage {
            role: Role::Assistant,
            content: Some("this is not valid json at all".into()),
            audio: None,
            function_call: None,
            refusal: None,
            tool_calls: None,
        };

        let choice_fail = BatchChoiceBuilder::default()
            .index(0_u32)
            .finish_reason(FinishReason::Stop)
            .logprobs(None)
            .message(invalid_msg)
            .build()
            .unwrap();

        let success_body = BatchSuccessResponseBodyBuilder::default()
            .object("response".to_string())
            .id("some-other-uuid".to_string())
            .created(0_u64)
            .model("test-model".to_string())
            .choices(vec![choice_fail])
            .usage(BatchUsage::mock())
            .build()
            .unwrap();

        // 4) Call handle_successful_response with ExpectedContentType=Json
        let rc = handle_successful_response::<MockItemForSuccess>(
            &success_body,
            workspace.as_ref(),
            &ExpectedContentType::Json
        ).await;

        // 5) Confirm it fails
        assert!(rc.is_err(), "We expect an error due to invalid JSON content");

        // 6) Confirm the "some-other-uuid" file is in the ephemeral repairs dir
        let repair_path = repairs_dir.join("some-other-uuid");
        trace!("Asserting that repair file path exists: {:?}", repair_path);
        assert!(repair_path.exists(), "A repair file must be created for invalid JSON");

        trace!("===== END TEST: test_handle_successful_response_json_failure =====");
    }
}