batch_mode_process_response/
handle_successful_response.rs

1// ---------------- [ File: batch-mode-process-response/src/handle_successful_response.rs ]
2crate::ix!();
3
4/// Produce a *flattened* JSON payload suitable for saving to `target/`
5///
6/// * If the original assistant JSON already matches the target structure,
7///   we return it unchanged.
8/// * If it was wrapped in a `{ "fields": { … } }` envelope, we drop the
9///   envelope and return only the inner object so that subsequent
10///   `T::load_from_file()` calls (which expect a *flat* object) succeed.
11///
12/// All paths are traced for debuggability.
13#[instrument(level = "trace", skip_all)]
14fn flatten_json_for_persistence(original: &serde_json::Value) -> serde_json::Value {
15    if let Some(inner) = original.get("fields") {
16        trace!("Detected `fields` wrapper — flattening for persistence.");
17        inner.clone()
18    } else {
19        trace!("No `fields` wrapper — JSON already flat.");
20        original.clone()
21    }
22}
23
24/// ---------------------------------------------------------------------------
25/// Guarantee *header echo*:  
26/// If the JSON produced by the LLM contains a `"header"` key, we make sure its
27/// value matches `T::name()`.  If not, we patch it **before the file is written**.
28/// This keeps the rule generic – no per‑type hard‑coding is required.
29/// ---------------------------------------------------------------------------
30#[inline(always)]
31fn enforce_header_echo(seed: &(dyn Named), json: &mut serde_json::Value) {
32    let expected = seed.name();
33    match json.get_mut("header") {
34        Some(h) if h == expected.as_ref() => {} // already fine
35        Some(h) => *h = serde_json::Value::String(expected.into_owned()),
36        None => {
37            json.as_object_mut()
38                .expect("output must be object")
39                .insert("header".into(), serde_json::Value::String(expected.into_owned()));
40        }
41    }
42}
43
44#[instrument(level = "trace", skip_all)]
45pub async fn handle_successful_response<T>(
46    success_body:          &BatchSuccessResponseBody,
47    workspace:             &dyn BatchWorkspaceInterface,
48    expected_content_type: &ExpectedContentType,
49    original_seed:         &(dyn Named + Send + Sync),
50) -> Result<(), BatchSuccessResponseHandlingError>
51where
52    T: 'static + Send + Sync + Named + DeserializeOwned + GetTargetPathForAIExpansion,
53{
54    trace!(
55        "Entering handle_successful_response with success_body ID: {}",
56        success_body.id()
57    );
58    trace!(
59        "success_body => finish_reason={:?}, total_choices={}",
60        success_body.choices().get(0).map(|c| c.finish_reason()),
61        success_body.choices().len()
62    );
63
64    let choice = &success_body.choices()[0];
65    let message_content = choice.message().content();
66    trace!(
67        "Pulled first choice => finish_reason={:?}",
68        choice.finish_reason()
69    );
70
71    if *choice.finish_reason() == FinishReason::Length {
72        trace!("Detected finish_reason=Length => calling handle_finish_reason_length");
73        handle_finish_reason_length(success_body.id(), message_content).await?;
74        trace!(
75            "Returned from handle_finish_reason_length with success_body ID: {}",
76            success_body.id()
77        );
78    }
79
80    match expected_content_type {
81        ExpectedContentType::Json => {
82            trace!(
83                "ExpectedContentType::Json => about to extract/repair JSON for success_body ID: {}",
84                success_body.id()
85            );
86
87            match message_content.extract_clean_parse_json_with_repair() {
88                Ok(mut json_content) => {
89                    debug!(
90                        "JSON parse/repair succeeded for success_body ID: {}",
91                        success_body.id()
92                    );
93
94                    // ---------- DESERIALISE (with optional wrapper) ----------
95                    let typed_item: T = match deserialize_json_with_optional_fields_wrapper::<T>(&json_content) {
96                        Ok(item) => item,
97                        Err(e) => {
98                            error!("Deserialization into T failed: {:?}", e);
99                            handle_failed_json_repair(success_body.id(), message_content, workspace).await?;
100                            return Err(e.into());
101                        }
102                    };
103
104                    /* Enforce the header‑echo rule *generically*. */
105                    enforce_header_echo(original_seed, &mut json_content);
106
107                    //------------------------------------------------------------------
108
109                    trace!("Wrapping typed_item in Arc => T::name()={}", typed_item.name());
110                    let typed_item_arc: Arc<
111                        dyn GetTargetPathForAIExpansion + Send + Sync + 'static,
112                    > = Arc::new(typed_item);
113
114                    // Choose where we will write the output **before** we flatten it.
115                    let target_path = workspace.target_path(&typed_item_arc, expected_content_type);
116                    trace!("Target path computed => {:?}", target_path);
117
118                    // ---------- NEW: ALWAYS WRITE A *FLAT* OBJECT ----------
119                    let flattened_json = flatten_json_for_persistence(&json_content);
120                    let serialized_json = match serde_json::to_string_pretty(&flattened_json) {
121                        Ok(s) => {
122                            trace!(
123                                "Successfully created pretty JSON string for success_body ID: {}",
124                                success_body.id()
125                            );
126                            s
127                        }
128                        Err(e) => {
129                            error!("Re‑serialization to pretty JSON failed: {:?}", e);
130                            return Err(JsonParseError::SerdeError(e).into());
131                        }
132                    };
133                    //--------------------------------------------------------
134
135                    info!("writing JSON output to {:?}", target_path);
136                    write_to_file(&target_path, &serialized_json).await?;
137                    trace!("Successfully wrote JSON file => {:?}", target_path);
138                    trace!(
139                        "Exiting handle_successful_response with success_body ID: {}",
140                        success_body.id()
141                    );
142                    Ok(())
143                }
144                Err(e_extract) => {
145                    warn!(
146                        "JSON extraction/repair failed for success_body ID: {} with error: {:?}",
147                        success_body.id(),
148                        e_extract
149                    );
150                    let failed_id = success_body.id();
151                    trace!("Calling handle_failed_json_repair for ID={}", failed_id);
152                    handle_failed_json_repair(failed_id, message_content, workspace).await?;
153                    trace!(
154                        "Returned from handle_failed_json_repair => now returning error for ID={}",
155                        failed_id
156                    );
157                    Err(e_extract.into())
158                }
159            }
160        }
161        ExpectedContentType::PlainText => {
162            trace!(
163                "Received plain text content for request {} => length={}",
164                success_body.id(),
165                message_content.len()
166            );
167            let index = BatchIndex::from_uuid_str(success_body.id())?;
168            trace!("Parsed BatchIndex => {:?}", index);
169
170            let text_path = workspace.text_storage_path(&index);
171            info!("writing plain text output to {:?}", text_path);
172            write_to_file(&text_path, message_content.as_str()).await?;
173            trace!("Successfully wrote plain text file => {:?}", text_path);
174            trace!(
175                "Exiting handle_successful_response with success_body ID: {}",
176                success_body.id()
177            );
178            Ok(())
179        }
180        _ => {
181            todo!("Unsupported ExpectedContentType variant encountered.")
182        }
183    }
184}
185
#[cfg(test)]
mod handle_successful_response_tests {
    use super::*;

    /// Minimal deserialization target used to drive `handle_successful_response`.
    #[derive(Debug, Deserialize, Serialize, NamedItem)]
    pub struct MockItemForSuccess {
        pub name: String,
    }

    /// Feeds non-JSON assistant content through the `Json` path and checks that
    /// the call fails and a file named after the response id appears in the
    /// workspace's failed-JSON-repairs directory.
    #[traced_test]
    async fn test_handle_successful_response_json_failure() {
        trace!("===== BEGIN TEST: test_handle_successful_response_json_failure =====");

        // 1) Use ephemeral default workspace (no overrides).
        let workspace = BatchWorkspace::new_temp().await.unwrap();
        info!("Created ephemeral workspace: {:?}", workspace);

        // 2) Locate the repairs dir; the failure path is expected to write into it.
        let repairs_dir = workspace.failed_json_repairs_dir();

        // 3) Create a response that is *not* valid JSON
        let invalid_msg = ChatCompletionResponseMessage {
            role: Role::Assistant,
            content: Some("this is not valid json at all".into()),
            audio: None,
            function_call: None,
            refusal: None,
            tool_calls: None,
        };

        let choice_fail = BatchChoiceBuilder::default()
            .index(0_u32)
            .finish_reason(FinishReason::Stop)
            .logprobs(None)
            .message(invalid_msg)
            .build()
            .unwrap();

        let success_body = BatchSuccessResponseBodyBuilder::default()
            .object("response".to_string())
            .id("some-other-uuid".to_string())
            .created(0_u64)
            .model("test-model".to_string())
            .choices(vec![choice_fail])
            .usage(BatchUsage::mock())
            .build()
            .unwrap();

        // The handler requires the original seed for the header-echo rule; the
        // failure path under test never reaches it, so any named item will do.
        let seed = MockItemForSuccess {
            name: "seed-item".to_string(),
        };

        // 4) Call handle_successful_response with ExpectedContentType=Json
        let rc = handle_successful_response::<MockItemForSuccess>(
            &success_body,
            workspace.as_ref(),
            &ExpectedContentType::Json,
            &seed,
        ).await;

        // 5) Confirm it fails
        assert!(rc.is_err(), "We expect an error due to invalid JSON content");

        // 6) Confirm the "some-other-uuid" file is in the ephemeral repairs dir
        let repair_path = repairs_dir.join("some-other-uuid");
        trace!("Asserting that repair file path exists: {:?}", repair_path);
        assert!(repair_path.exists(), "A repair file must be created for invalid JSON");

        trace!("===== END TEST: test_handle_successful_response_json_failure =====");
    }
}